diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 399bdb066..1fbc1a476 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -120,6 +120,9 @@ def __ne__(self, other): else: return not equality_val + def __repr__(self): + return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})" + def verify_path(path, is_collection) -> None: """Verifies that a ``path`` has the correct form. diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index fd016dfe7..efc4a47c0 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -56,6 +56,8 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.pipeline_source import PipelineSource if TYPE_CHECKING: # pragma: NO COVER import datetime @@ -438,3 +440,10 @@ def transaction( A transaction attached to this client. """ return AsyncTransaction(self, max_attempts=max_attempts, read_only=read_only) + + @property + def _pipeline_cls(self): + return AsyncPipeline + + def pipeline(self) -> PipelineSource: + return PipelineSource(self) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py new file mode 100644 index 000000000..d476cc283 --- /dev/null +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -0,0 +1,134 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases +""" + +from __future__ import annotations +from typing import TYPE_CHECKING +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream +from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot +from google.cloud.firestore_v1.pipeline_result import PipelineResult + +if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + from google.cloud.firestore_v1.pipeline_expressions import Constant + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions + + +class AsyncPipeline(_BasePipeline): + """ + Pipelines allow for complex data transformations and queries involving + multiple stages like filtering, projection, aggregation, and vector search. + + This class extends `_BasePipeline` and provides methods to execute the + defined pipeline stages using an asynchronous `AsyncClient`. 
+ + Usage Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> + >>> async def run_pipeline(): + ... client = AsyncClient(...) + ... pipeline = client.pipeline() + ... .collection("books") + ... .where(Field.of("published").gt(1980)) + ... .select("title", "author") + ... async for result in pipeline.stream(): + ... print(result) + + Use `client.pipeline()` to create instances of this class. + + .. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases + """ + + def __init__(self, client: AsyncClient, *stages: stages.Stage): + """ + Initializes an asynchronous Pipeline. + + Args: + client: The asynchronous `AsyncClient` instance to use for execution. + *stages: Initial stages for the pipeline. + """ + super().__init__(client, *stages) + + async def execute( + self, + *, + transaction: "AsyncTransaction" | None = None, + read_time: datetime.datetime | None = None, + explain_options: PipelineExplainOptions | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> PipelineSnapshot[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned list. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method argument if there is a conflict (e.g. explain_options) + """ + kwargs = {k: v for k, v in locals().items() if k != "self"} + stream = AsyncPipelineStream(PipelineResult, self, **kwargs) + results = [result async for result in stream] + return PipelineSnapshot(results, stream) + + def stream( + self, + *, + read_time: datetime.datetime | None = None, + transaction: "AsyncTransaction" | None = None, + explain_options: PipelineExplainOptions | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> AsyncPipelineStream[PipelineResult]: + """ + Process this pipeline as a stream, providing results through an AsyncIterable + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. 
+ explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned generator. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method argument if there is a conflict (e.g. explain_options) + """ + kwargs = {k: v for k, v in locals().items() if k != "self"} + return AsyncPipelineStream(PipelineResult, self, **kwargs) diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index c5e6a7b7f..6f392207e 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -21,9 +21,10 @@ from __future__ import annotations import abc +import itertools from abc import ABC -from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -33,6 +34,10 @@ from google.cloud.firestore_v1.types import ( StructuredAggregationQuery, ) +from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction +from google.cloud.firestore_v1.pipeline_expressions import Count +from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression +from google.cloud.firestore_v1.pipeline_expressions import Field # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER @@ -43,6 +48,7 @@ from google.cloud.firestore_v1.stream_generator import ( StreamGenerator, ) + from google.cloud.firestore_v1.pipeline_source import PipelineSource import datetime @@ -66,6 +72,9 @@ def __init__(self, alias: str, value: float, read_time=None): def __repr__(self): return f"" + def _to_dict(self): + return {self.alias: self.value} + class BaseAggregation(ABC): def __init__(self, alias: str | None = None): @@ -75,6 +84,27 @@ def __init__(self, alias: str | None = None): def _to_protobuf(self): """Convert this instance to the protobuf representation""" + @abc.abstractmethod + def _to_pipeline_expr( + self, autoindexer: Iterable[int] + ) -> AliasedExpression[AggregateFunction]: + """ + Convert this instance to a pipeline expression for use with pipeline.aggregate() + + Args: + autoindexer: If an alias isn't supplied, one should be created with the format "field_n" + The autoindexer is an iterable that provides the `n` value to use for each expression + """ + + def _pipeline_alias(self, autoindexer): + """ + Helper to build the alias for the pipeline expression + """ + if self.alias is not None: + return self.alias + else: + return f"field_{next(autoindexer)}" + class CountAggregation(BaseAggregation): def __init__(self, alias: str | None = None): @@ -88,6 +118,9 @@ def _to_protobuf(self): aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Count().as_(self._pipeline_alias(autoindexer)) + class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): @@ -107,6 +140,9 @@ def _to_protobuf(self): aggregation_pb.sum.field.field_path = self.field_ref return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer)) + class 
AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): @@ -126,6 +162,9 @@ def _to_protobuf(self): aggregation_pb.avg.field.field_path = self.field_ref return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer)) + def _query_response_to_result( response_pb, @@ -317,3 +356,21 @@ def stream( StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]: A generator of the query results. """ + + def _build_pipeline(self, source: "PipelineSource"): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Args: + source: the PipelineSource to build the pipeline off of + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + # use autoindexer to keep track of which field number to use for un-aliased fields + autoindexer = itertools.count(start=1) + exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations] + return self._nested_query._build_pipeline(source).aggregate(*exprs) diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index f3eeeae49..4ba8a7c06 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -37,6 +37,7 @@ Optional, Tuple, Union, + Type, ) import google.api_core.client_options @@ -61,6 +62,8 @@ from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.base_pipeline import _BasePipeline DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" @@ -502,6 +505,20 @@ def transaction( ) -> BaseTransaction: raise NotImplementedError + def pipeline(self) -> PipelineSource: + """ + Start a pipeline with this client. + + Returns: + :class:`~google.cloud.firestore_v1.pipeline_source.PipelineSource`: + A pipeline source that uses this client + """ + raise NotImplementedError + + @property + def _pipeline_cls(self) -> Type["_BasePipeline"]: + raise NotImplementedError + def _reference_info(references: list) -> Tuple[list, dict]: """Get information about document references.
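The `pipeline()` entry point added to `BaseClient` above (and implemented on `Client` and `AsyncClient` elsewhere in this diff) returns a `PipelineSource`, which is also what the `_build_pipeline()` helpers target when converting existing queries. A minimal usage sketch of how these pieces compose, based only on the docstrings in this diff — the `books` collection and field names are illustrative, and the expression methods (`greater_than`, `descending`) follow the definitions in `pipeline_expressions.py`:

```python
from google.cloud.firestore_v1 import Client
from google.cloud.firestore_v1.pipeline_expressions import Field

client = Client()  # assumes default project and credentials

# client.pipeline() returns a PipelineSource; each stage call returns a new
# Pipeline with the stage appended, so the chain below never mutates in place.
pipeline = (
    client.pipeline()
    .collection("books")                              # illustrative collection
    .where(Field.of("published").greater_than(1980))
    .select("title", "author")
    .sort(Field.of("published").descending())
    .limit(10)
)

# execute() returns a PipelineSnapshot; iterating it yields PipelineResult objects.
for result in pipeline.execute():
    print(result)
```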
diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index be817c5fe..070e54cc4 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -49,6 +49,7 @@ from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.pipeline_source import PipelineSource from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator @@ -603,6 +604,21 @@ def find_nearest( distance_threshold=distance_threshold, ) + def _build_pipeline(self, source: "PipelineSource"): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Args: + source: the PipelineSource to build the pipeline off o + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + return self._query()._build_pipeline(source) + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py new file mode 100644 index 000000000..153564663 --- /dev/null +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -0,0 +1,610 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Sequence, TYPE_CHECKING +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.types.pipeline import ( + StructuredPipeline as StructuredPipeline_pb, +) +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.pipeline_expressions import ( + AggregateFunction, + AliasedExpression, + Expression, + Field, + BooleanExpression, + Selectable, +) + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + + +class _BasePipeline: + """ + Base class for building Firestore data transformation and query pipelines. + + This class is not intended to be instantiated directly. + Use `client.pipeline()` to create pipeline instances. + """ + + def __init__(self, client: Client | AsyncClient): + """ + Initializes a new pipeline. + + Pipelines should not be instantiated directly. 
Instead, + call client.pipeline() to create an instance + + Args: + client: The client associated with the pipeline + """ + self._client = client + self.stages: Sequence[stages.Stage] = tuple() + + @classmethod + def _create_with_stages( + cls, client: Client | AsyncClient, *stages + ) -> _BasePipeline: + """ + Initializes a new pipeline with the given stages. + + Pipeline classes should not be instantiated directly. + + Args: + client: The client associated with the pipeline + *stages: Initial stages for the pipeline. + """ + new_instance = cls(client) + new_instance.stages = tuple(stages) + return new_instance + + def __repr__(self): + cls_str = type(self).__name__ + if not self.stages: + return f"{cls_str}()" + elif len(self.stages) == 1: + return f"{cls_str}({self.stages[0]!r})" + else: + stages_str = ",\n ".join([repr(s) for s in self.stages]) + return f"{cls_str}(\n {stages_str}\n)" + + def _to_pb(self, **options) -> StructuredPipeline_pb: + return StructuredPipeline_pb( + pipeline={"stages": [s._to_pb() for s in self.stages]}, + options=options, + ) + + def _append(self, new_stage): + """ + Create a new Pipeline object with a new stage appended + """ + return self.__class__._create_with_stages(self._client, *self.stages, new_stage) + + def add_fields(self, *fields: Selectable) -> "_BasePipeline": + """ + Adds new fields to outputs from previous stages. + + This stage allows you to compute values on-the-fly based on existing data + from previous stages or constants. You can use this to create new fields + or overwrite existing ones (if there is name overlap). + + The added fields are defined using `Selectable` expressions, which can be: + - `Field`: References an existing document field. + - `Function`: Performs a calculation using functions like `add`, + `multiply` with assigned aliases using `Expression.as_()`. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.add_fields( + ... Field.of("rating").as_("bookRating"), # Rename 'rating' to 'bookRating' + ... add(5, Field.of("quantity")).as_("totalCost") # Calculate 'totalCost' + ... ) + + Args: + *fields: The fields to add to the documents, specified as `Selectable` + expressions. + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.AddFields(*fields)) + + def remove_fields(self, *fields: Field | str) -> "_BasePipeline": + """ + Removes fields from outputs of previous stages. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Remove by name + >>> pipeline = pipeline.remove_fields("rating", "cost") + >>> # Remove by Field object + >>> pipeline = pipeline.remove_fields(Field.of("rating"), Field.of("cost")) + + + Args: + *fields: The fields to remove, specified as field names (str) or + `Field` objects. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.RemoveFields(*fields)) + + def select(self, *selections: str | Selectable) -> "_BasePipeline": + """ + Selects or creates a set of fields from the outputs of previous stages. + + The selected fields are defined using `Selectable` expressions or field names: + - `Field`: References an existing document field. + - `Function`: Represents the result of a function with an assigned alias + name using `Expression.as_()`. + - `str`: The name of an existing field. 
+ + If no selections are provided, the output of this stage is empty. Use + `add_fields()` instead if only additions are desired. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper + >>> pipeline = client.pipeline().collection("books") + >>> # Select by name + >>> pipeline = pipeline.select("name", "address") + >>> # Select using Field and Function expressions + >>> pipeline = pipeline.select( + ... Field.of("name"), + ... Field.of("address").to_upper().as_("upperAddress"), + ... ) + + Args: + *selections: The fields to include in the output documents, specified as + field names (str) or `Selectable` expressions. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Select(*selections)) + + def where(self, condition: BooleanExpression) -> "_BasePipeline": + """ + Filters the documents from previous stages to only include those matching + the specified `BooleanExpression`. + + This stage allows you to apply conditions to the data, similar to a "WHERE" + clause in SQL. You can filter documents based on their field values, using + implementations of `BooleanExpression`, typically including but not limited to: + - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. + - logical operators: `And`, `Or`, `Not`, etc. + - advanced functions: `regex_matches`, `array_contains`, etc. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, And, + >>> pipeline = client.pipeline().collection("books") + >>> # Using static functions + >>> pipeline = pipeline.where( + ... And( + ... Field.of("rating").gt(4.0), # Filter for ratings > 4.0 + ... Field.of("genre").eq("Science Fiction") # Filter for genre + ... ) + ... ) + >>> # Using methods on expressions + >>> pipeline = pipeline.where( + ... And( + ... Field.of("rating").gt(4.0), + ... Field.of("genre").eq("Science Fiction") + ... ) + ... ) + + + Args: + condition: The `BooleanExpression` to apply. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Where(condition)) + + def find_nearest( + self, + field: str | Expression, + vector: Sequence[float] | "Vector", + distance_measure: "DistanceMeasure", + options: stages.FindNearestOptions | None = None, + ) -> "_BasePipeline": + """ + Performs vector distance (similarity) search with given parameters on the + stage inputs. + + This stage adds a "nearest neighbor search" capability to your pipelines. + Given a field or expression that evaluates to a vector and a target vector, + this stage will identify and return the inputs whose vector is closest to + the target vector, using the specified distance measure and options. + + Example: + >>> from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + >>> from google.cloud.firestore_v1.pipeline_stages import FindNearestOptions + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> + >>> target_vector = [0.1, 0.2, 0.3] + >>> pipeline = client.pipeline().collection("books") + >>> # Find using field name + >>> pipeline = pipeline.find_nearest( + ... "topicVectors", + ... target_vector, + ... DistanceMeasure.COSINE, + ... options=FindNearestOptions(limit=10, distance_field="distance") + ... ) + >>> # Find using Field expression + >>> pipeline = pipeline.find_nearest( + ... Field.of("topicVectors"), + ... target_vector, + ... DistanceMeasure.COSINE, + ... options=FindNearestOptions(limit=10, distance_field="distance") + ... 
) + + Args: + field: The name of the field (str) or an expression (`Expression`) that + evaluates to the vector data. This field should store vector values. + vector: The target vector (sequence of floats or `Vector` object) to + compare against. + distance_measure: The distance measure (`DistanceMeasure`) to use + (e.g., `DistanceMeasure.COSINE`, `DistanceMeasure.EUCLIDEAN`). + limit: The maximum number of nearest neighbors to return. + options: Configuration options (`FindNearestOptions`) for the search, + such as limit and output distance field name. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append( + stages.FindNearest(field, vector, distance_measure, options) + ) + + def replace_with( + self, + field: Selectable, + ) -> "_BasePipeline": + """ + Fully overwrites all fields in a document with those coming from a nested map. + + This stage allows you to emit a map value as a document. Each key of the map becomes a field + on the document that contains the corresponding value. + + Example: + Input document: + ```json + { + "name": "John Doe Jr.", + "parents": { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + } + ``` + + >>> # Emit the 'parents' map as the document + >>> pipeline = client.pipeline().collection("people").replace_with(Field.of("parents")) + + Output document: + ```json + { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + ``` + + Args: + field: The `Selectable` field containing the map whose content will + replace the document. + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.ReplaceWith(field)) + + def sort(self, *orders: stages.Ordering) -> "_BasePipeline": + """ + Sorts the documents from previous stages based on one or more `Ordering` criteria. + + This stage allows you to order the results of your pipeline. You can specify + multiple `Ordering` instances to sort by multiple fields or expressions in + ascending or descending order. If documents have the same value for a sorting + criterion, the next specified ordering will be used. If all orderings result + in equal comparison, the documents are considered equal and the relative order + is unspecified. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Sort books by rating descending, then title ascending + >>> pipeline = pipeline.sort( + ... Field.of("rating").descending(), + ... Field.of("title").ascending() + ... ) + + Args: + *orders: One or more `Ordering` instances specifying the sorting criteria. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Sort(*orders)) + + def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline": + """ + Performs a pseudo-random sampling of the documents from the previous stage. + + This stage filters documents pseudo-randomly. + - If an `int` limit is provided, it specifies the maximum number of documents + to emit. If fewer documents are available, all are passed through. + - If `SampleOptions` are provided, they specify how sampling is performed + (e.g., by document count or percentage). + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import SampleOptions + >>> pipeline = client.pipeline().collection("books") + >>> # Sample 10 books, if available. 
+ >>> pipeline = pipeline.sample(10) + >>> pipeline = pipeline.sample(SampleOptions.doc_limit(10)) + >>> # Sample 50% of books. + >>> pipeline = pipeline.sample(SampleOptions.percentage(0.5)) + + + Args: + limit_or_options: Either an integer specifying the maximum number of + documents to sample, or a `SampleOptions` object. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Sample(limit_or_options)) + + def union(self, other: "_BasePipeline") -> "_BasePipeline": + """ + Performs a union of all documents from this pipeline and another pipeline, + including duplicates. + + This stage passes through documents from the previous stage of this pipeline, + and also passes through documents from the previous stage of the `other` + pipeline provided. The order of documents emitted from this stage is undefined. + + Example: + >>> books_pipeline = client.pipeline().collection("books") + >>> magazines_pipeline = client.pipeline().collection("magazines") + >>> # Emit documents from both collections + >>> combined_pipeline = books_pipeline.union(magazines_pipeline) + + Args: + other: The other `Pipeline` whose results will be unioned with this one. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Union(other)) + + def unnest( + self, + field: str | Selectable, + alias: str | Field | None = None, + options: stages.UnnestOptions | None = None, + ) -> "_BasePipeline": + """ + Produces a document for each element in an array field from the previous stage document. + + For each previous stage document, this stage will emit zero or more augmented documents. The + input array found in the previous stage document field specified by the `fieldName` parameter, + will emit an augmented document for each input array element. The input array element will + augment the previous stage document by setting the `alias` field with the array element value. + If `alias` is unset, the data in `field` will be overwritten. + + Example: + Input document: + ```json + { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... } + ``` + + >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions + >>> pipeline = client.pipeline().collection("books") + >>> # Emit a document for each tag + >>> pipeline = pipeline.unnest("tags", alias="tag") + + Output documents (without options): + ```json + { "title": "The Hitchhiker's Guide", "tag": "comedy", ... } + { "title": "The Hitchhiker's Guide", "tag": "sci-fi", ... } + ``` + + Optionally, `UnnestOptions` can specify a field to store the original index + of the element within the array + + Example: + Input document: + ```json + { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... } + ``` + + >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions + >>> pipeline = client.pipeline().collection("books") + >>> # Emit a document for each tag, including the index + >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex")) + + Output documents (with index_field="tagIndex"): + ```json + { "title": "The Hitchhiker's Guide", "tags": "comedy", "tagIndex": 0, ... } + { "title": "The Hitchhiker's Guide", "tags": "sci-fi", "tagIndex": 1, ... } + ``` + + Args: + field: The name of the field containing the array to unnest. + alias The alias field is used as the field name for each element within the output array. 
+ If unset, or if `alias` matches the `field`, the output data will overwrite the original field. + options: Optional `UnnestOptions` to configure additional behavior, like adding an index field. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Unnest(field, alias, options)) + + def raw_stage(self, name: str, *params: Expression) -> "_BasePipeline": + """ + Adds a stage to the pipeline by specifying the stage name as an argument. This does not offer any + type safety on the stage params and requires the caller to know the order (and optionally names) + of parameters accepted by the stage. + + This class provides a way to call stages that are supported by the Firestore backend but that + are not implemented in the SDK version being used. + + Example: + >>> # Assume we don't have a built-in "where" stage + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.raw_stage("where", Field.of("published").lt(900)) + >>> pipeline = pipeline.select("title", "author") + + Args: + name: The name of the stage. + *params: A sequence of `Expression` objects representing the parameters for the stage. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.RawStage(name, *params)) + + def offset(self, offset: int) -> "_BasePipeline": + """ + Skips the first `offset` number of documents from the results of previous stages. + + This stage is useful for implementing pagination, allowing you to retrieve + results in chunks. It is typically used in conjunction with `limit()` to + control the size of each page. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Retrieve the second page of 20 results (assuming sorted) + >>> pipeline = pipeline.sort(Field.of("published").descending()) + >>> pipeline = pipeline.offset(20) # Skip the first 20 results + >>> pipeline = pipeline.limit(20) # Take the next 20 results + + Args: + offset: The non-negative number of documents to skip. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Offset(offset)) + + def limit(self, limit: int) -> "_BasePipeline": + """ + Limits the maximum number of documents returned by previous stages to `limit`. + + This stage is useful for controlling the size of the result set, often used for: + - **Pagination:** In combination with `offset()` to retrieve specific pages. + - **Top-N queries:** To get a limited number of results after sorting. + - **Performance:** To prevent excessive data transfer. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Limit the results to the top 10 highest-rated books + >>> pipeline = pipeline.sort(Field.of("rating").descending()) + >>> pipeline = pipeline.limit(10) + + Args: + limit: The non-negative maximum number of documents to return. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Limit(limit)) + + def aggregate( + self, + *accumulators: AliasedExpression[AggregateFunction], + groups: Sequence[str | Selectable] = (), + ) -> "_BasePipeline": + """ + Performs aggregation operations on the documents from previous stages, + optionally grouped by specified fields or expressions. 
+ + This stage allows you to calculate aggregate values (like sum, average, count, + min, max) over a set of documents. + + - **Accumulators:** Define the aggregation calculations using `AggregateFunction` + expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined + with `as_()` to name the result field. + - **Groups:** Optionally specify fields (by name or `Selectable`) to group + the documents by. Aggregations are then performed within each distinct group. + If no groups are provided, the aggregation is performed over the entire input. + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Calculate the average rating and total count for all books + >>> pipeline = pipeline.aggregate( + ... Field.of("rating").avg().as_("averageRating"), + ... Field.of("rating").count().as_("totalBooks") + ... ) + >>> # Calculate the average rating for each genre + >>> pipeline = pipeline.aggregate( + ... Field.of("rating").avg().as_("avg_rating"), + ... groups=["genre"] # Group by the 'genre' field + ... ) + >>> # Calculate the count for each author, grouping by Field object + >>> pipeline = pipeline.aggregate( + ... Count().as_("bookCount"), + ... groups=[Field.of("author")] + ... ) + + + Args: + *accumulators: One or more expressions defining the aggregations to perform and their + corresponding output names. + groups: An optional sequence of field names (str) or `Selectable` + expressions to group by before aggregating. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Aggregate(*accumulators, groups=groups)) + + def distinct(self, *fields: str | Selectable) -> "_BasePipeline": + """ + Returns documents with distinct combinations of values for the specified + fields or expressions. + + This stage filters the results from previous stages to include only one + document for each unique combination of values in the specified `fields`. + The output documents contain only the fields specified in the `distinct` call. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper + >>> pipeline = client.pipeline().collection("books") + >>> # Get a list of unique genres (output has only 'genre' field) + >>> pipeline = pipeline.distinct("genre") + >>> # Get unique combinations of author (uppercase) and genre + >>> pipeline = pipeline.distinct( + ... Field.of("author").to_upper().as_("authorUpper"), + ... Field.of("genre") + ... ) + + + Args: + *fields: Field names (str) or `Selectable` expressions to consider when + determining distinct value combinations. The output will only + contain these fields/expressions. 
+ + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Distinct(*fields)) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 2de95b79a..b1b74fcf1 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -59,6 +59,7 @@ query, ) from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -66,6 +67,7 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator + from google.cloud.firestore_v1.pipeline_source import PipelineSource import datetime @@ -1128,6 +1130,73 @@ def recursive(self: QueryType) -> QueryType: return copied + def _build_pipeline(self, source: "PipelineSource"): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Args: + source: the PipelineSource to build the pipeline off of + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + if self._all_descendants: + ppl = source.collection_group(self._parent.id) + else: + ppl = source.collection(self._parent._path) + + # Filters + for filter_ in self._field_filters: + ppl = ppl.where( + pipeline_expressions.BooleanExpression._from_query_filter_pb( + filter_, self._client + ) + ) + + # Projections + if self._projection and self._projection.fields: + ppl = ppl.select(*[field.field_path for field in self._projection.fields]) + + # Orders + orders = self._normalize_orders() + if orders: + exists = [] + orderings = [] + for order in orders: + field = pipeline_expressions.Field.of(order.field.field_path) + exists.append(field.exists()) + direction = ( + "ascending" + if order.direction == StructuredQuery.Direction.ASCENDING + else "descending" + ) + orderings.append(pipeline_expressions.Ordering(field, direction)) + + # Add exists filters to match Query's implicit orderby semantics. + if len(exists) == 1: + ppl = ppl.where(exists[0]) + else: + ppl = ppl.where(pipeline_expressions.And(*exists)) + + # Add sort orderings + ppl = ppl.sort(*orderings) + + # Cursors, Limit and Offset + if self._start_at or self._end_at or self._limit_to_last: + raise NotImplementedError( + "Query to Pipeline conversion: cursors and limit_to_last is not supported yet." + ) + else: # Limit & Offset without cursors + if self._offset: + ppl = ppl.offset(self._offset) + if self._limit: + ppl = ppl.limit(self._limit) + + return ppl + def _comparator(self, doc1, doc2) -> int: _orders = self._orders diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 54943aded..ba2ca176d 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -50,6 +50,8 @@ grpc as firestore_grpc_transport, ) from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.pipeline_source import PipelineSource if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.bulk_writer import BulkWriter @@ -411,3 +413,10 @@ def transaction( A transaction attached to this client. 
""" return Transaction(self, max_attempts=max_attempts, read_only=read_only) + + @property + def _pipeline_cls(self): + return Pipeline + + def pipeline(self) -> PipelineSource: + return PipelineSource(self) diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 27ac6cc45..32516d3be 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -16,7 +16,7 @@ from __future__ import annotations import re from collections import abc -from typing import Iterable, cast +from typing import Any, Iterable, cast, MutableMapping _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data" _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}" @@ -170,7 +170,7 @@ def render_field_path(field_names: Iterable[str]): get_field_path = render_field_path # backward-compatibility -def get_nested_value(field_path: str, data: dict): +def get_nested_value(field_path: str, data: MutableMapping[str, Any]): """Get a (potentially nested) value from a dictionary. If the data is nested, for example: diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py new file mode 100644 index 000000000..bce43fc86 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline.py @@ -0,0 +1,131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" + +from __future__ import annotations +from typing import TYPE_CHECKING +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1.pipeline_result import PipelineStream +from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot +from google.cloud.firestore_v1.pipeline_result import PipelineResult + +if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_expressions import Constant + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions + + +class Pipeline(_BasePipeline): + """ + Pipelines allow for complex data transformations and queries involving + multiple stages like filtering, projection, aggregation, and vector search. + + Usage Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> + >>> def run_pipeline(): + ... client = Client(...) + ... pipeline = client.pipeline() + ... .collection("books") + ... .where(Field.of("published").gt(1980)) + ... .select("title", "author") + ... for result in pipeline.execute(): + ... print(result) + + Use `client.pipeline()` to create instances of this class. + + .. 
warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. + """ + + def __init__(self, client: Client, *stages: stages.Stage): + """ + Initializes a Pipeline. + + Args: + client: The `Client` instance to use for execution. + *stages: Initial stages for the pipeline. + """ + super().__init__(client, *stages) + + def execute( + self, + *, + transaction: "Transaction" | None = None, + read_time: datetime.datetime | None = None, + explain_options: PipelineExplainOptions | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> PipelineSnapshot[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned list. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method argument if there is a conflict (e.g. explain_options) + """ + kwargs = {k: v for k, v in locals().items() if k != "self"} + stream = PipelineStream(PipelineResult, self, **kwargs) + results = [result for result in stream] + return PipelineSnapshot(results, stream) + + def stream( + self, + *, + transaction: "Transaction" | None = None, + read_time: datetime.datetime | None = None, + explain_options: PipelineExplainOptions | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> PipelineStream[PipelineResult]: + """ + Process this pipeline as a stream, providing results through an Iterable + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned generator. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method argument if there is a conflict (e.g. 
explain_options) + """ + kwargs = {k: v for k, v in locals().items() if k != "self"} + return PipelineStream(PipelineResult, self, **kwargs) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py new file mode 100644 index 000000000..c0ff3923a --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -0,0 +1,2026 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" + +from __future__ import annotations +from typing import ( + Any, + Generic, + TypeVar, + Sequence, +) +from abc import ABC +from abc import abstractmethod +from enum import Enum +import datetime +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1._helpers import encode_value +from google.cloud.firestore_v1._helpers import decode_value + +CONSTANT_TYPE = TypeVar( + "CONSTANT_TYPE", + str, + int, + float, + bool, + datetime.datetime, + bytes, + GeoPoint, + Vector, + None, +) + + +class Ordering: + """Represents the direction for sorting results in a pipeline.""" + + class Direction(Enum): + ASCENDING = "ascending" + DESCENDING = "descending" + + def __init__(self, expr, order_dir: Direction | str = Direction.ASCENDING): + """ + Initializes an Ordering instance + + Args: + expr (Expression | str): The expression or field path string to sort by. + If a string is provided, it's treated as a field path. + order_dir (Direction | str): The direction to sort in. + Defaults to ascending + """ + self.expr = expr if isinstance(expr, Expression) else Field.of(expr) + self.order_dir = ( + Ordering.Direction[order_dir.upper()] + if isinstance(order_dir, str) + else order_dir + ) + + def __repr__(self): + if self.order_dir is Ordering.Direction.ASCENDING: + order_str = ".ascending()" + else: + order_str = ".descending()" + return f"{self.expr!r}{order_str}" + + def _to_pb(self) -> Value: + return Value( + map_value={ + "fields": { + "direction": Value(string_value=self.order_dir.value), + "expression": self.expr._to_pb(), + } + } + ) + + +class Expression(ABC): + """Represents an expression that can be evaluated to a value within the + execution of a pipeline. + + Expressionessions are the building blocks for creating complex queries and + transformations in Firestore pipelines. They can represent: + + - **Field references:** Access values from document fields. + - **Literals:** Represent constant values (strings, numbers, booleans). + - **FunctionExpression calls:** Apply functions to one or more expressions. + - **Aggregations:** Calculate aggregate values (e.g., sum, average) over a set of documents. 
+ + The `Expression` class provides a fluent API for building expressions. You can chain + together method calls to create complex expressions. + """ + + def __repr__(self): + return f"{self.__class__.__name__}()" + + @abstractmethod + def _to_pb(self) -> Value: + raise NotImplementedError + + @staticmethod + def _cast_to_expr_or_convert_to_constant( + o: Any, include_vector=False + ) -> "Expression": + """Convert arbitrary object to an Expression.""" + if isinstance(o, Expression): + return o + if isinstance(o, dict): + return Map(o) + if isinstance(o, list): + if include_vector and all([isinstance(i, (float, int)) for i in o]): + return Constant(Vector(o)) + else: + return Array(o) + return Constant(o) + + class expose_as_static: + """ + Decorator to mark instance methods to be exposed as static methods as well as instance + methods. + + When called statically, the first argument is converted to a Field expression if needed. + + Example: + >>> Field.of("test").add(5) + >>> FunctionExpression.add("test", 5) + """ + + def __init__(self, instance_func): + self.instance_func = instance_func + + def static_func(self, first_arg, *other_args, **kwargs): + if not isinstance(first_arg, (Expression, str)): + raise TypeError( + f"'{self.instance_func.__name__}' must be called on an Expression or a string representing a field. got {type(first_arg)}." + ) + first_expr = ( + Field.of(first_arg) + if not isinstance(first_arg, Expression) + else first_arg + ) + return self.instance_func(first_expr, *other_args, **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self.static_func + else: + return self.instance_func.__get__(instance, owner) + + @expose_as_static + def add(self, other: Expression | float) -> "Expression": + """Creates an expression that adds this expression to another expression or constant. + + Example: + >>> # Add the value of the 'quantity' field and the 'reserve' field. + >>> Field.of("quantity").add(Field.of("reserve")) + >>> # Add 5 to the value of the 'age' field + >>> Field.of("age").add(5) + + Args: + other: The expression or constant value to add to this expression. + + Returns: + A new `Expression` representing the addition operation. + """ + return FunctionExpression( + "add", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def subtract(self, other: Expression | float) -> "Expression": + """Creates an expression that subtracts another expression or constant from this expression. + + Example: + >>> # Subtract the 'discount' field from the 'price' field + >>> Field.of("price").subtract(Field.of("discount")) + >>> # Subtract 20 from the value of the 'total' field + >>> Field.of("total").subtract(20) + + Args: + other: The expression or constant value to subtract from this expression. + + Returns: + A new `Expression` representing the subtraction operation. + """ + return FunctionExpression( + "subtract", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def multiply(self, other: Expression | float) -> "Expression": + """Creates an expression that multiplies this expression by another expression or constant. + + Example: + >>> # Multiply the 'quantity' field by the 'price' field + >>> Field.of("quantity").multiply(Field.of("price")) + >>> # Multiply the 'value' field by 2 + >>> Field.of("value").multiply(2) + + Args: + other: The expression or constant value to multiply by. + + Returns: + A new `Expression` representing the multiplication operation. 
+ """ + return FunctionExpression( + "multiply", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def divide(self, other: Expression | float) -> "Expression": + """Creates an expression that divides this expression by another expression or constant. + + Example: + >>> # Divide the 'total' field by the 'count' field + >>> Field.of("total").divide(Field.of("count")) + >>> # Divide the 'value' field by 10 + >>> Field.of("value").divide(10) + + Args: + other: The expression or constant value to divide by. + + Returns: + A new `Expression` representing the division operation. + """ + return FunctionExpression( + "divide", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def mod(self, other: Expression | float) -> "Expression": + """Creates an expression that calculates the modulo (remainder) to another expression or constant. + + Example: + >>> # Calculate the remainder of dividing the 'value' field by field 'divisor'. + >>> Field.of("value").mod(Field.of("divisor")) + >>> # Calculate the remainder of dividing the 'value' field by 5. + >>> Field.of("value").mod(5) + + Args: + other: The divisor expression or constant. + + Returns: + A new `Expression` representing the modulo operation. + """ + return FunctionExpression( + "mod", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def abs(self) -> "Expression": + """Creates an expression that calculates the absolute value of this expression. + + Example: + >>> # Get the absolute value of the 'change' field. + >>> Field.of("change").abs() + + Returns: + A new `Expression` representing the absolute value. + """ + return FunctionExpression("abs", [self]) + + @expose_as_static + def ceil(self) -> "Expression": + """Creates an expression that calculates the ceiling of this expression. + + Example: + >>> # Get the ceiling of the 'value' field. + >>> Field.of("value").ceil() + + Returns: + A new `Expression` representing the ceiling value. + """ + return FunctionExpression("ceil", [self]) + + @expose_as_static + def exp(self) -> "Expression": + """Creates an expression that computes e to the power of this expression. + + Example: + >>> # Compute e to the power of the 'value' field + >>> Field.of("value").exp() + + Returns: + A new `Expression` representing the exponential value. + """ + return FunctionExpression("exp", [self]) + + @expose_as_static + def floor(self) -> "Expression": + """Creates an expression that calculates the floor of this expression. + + Example: + >>> # Get the floor of the 'value' field. + >>> Field.of("value").floor() + + Returns: + A new `Expression` representing the floor value. + """ + return FunctionExpression("floor", [self]) + + @expose_as_static + def ln(self) -> "Expression": + """Creates an expression that calculates the natural logarithm of this expression. + + Example: + >>> # Get the natural logarithm of the 'value' field. + >>> Field.of("value").ln() + + Returns: + A new `Expression` representing the natural logarithm. + """ + return FunctionExpression("ln", [self]) + + @expose_as_static + def log(self, base: Expression | float) -> "Expression": + """Creates an expression that calculates the logarithm of this expression with a given base. + + Example: + >>> # Get the logarithm of 'value' with base 2. + >>> Field.of("value").log(2) + >>> # Get the logarithm of 'value' with base from 'base_field'. + >>> Field.of("value").log(Field.of("base_field")) + + Args: + base: The base of the logarithm. 
+ + Returns: + A new `Expression` representing the logarithm. + """ + return FunctionExpression( + "log", [self, self._cast_to_expr_or_convert_to_constant(base)] + ) + + @expose_as_static + def log10(self) -> "Expression": + """Creates an expression that calculates the base 10 logarithm of this expression. + + Example: + >>> Field.of("value").log10() + + Returns: + A new `Expression` representing the logarithm. + """ + return FunctionExpression("log10", [self]) + + @expose_as_static + def pow(self, exponent: Expression | float) -> "Expression": + """Creates an expression that calculates this expression raised to the power of the exponent. + + Example: + >>> # Raise 'base_val' to the power of 2. + >>> Field.of("base_val").pow(2) + >>> # Raise 'base_val' to the power of 'exponent_val'. + >>> Field.of("base_val").pow(Field.of("exponent_val")) + + Args: + exponent: The exponent. + + Returns: + A new `Expression` representing the power operation. + """ + return FunctionExpression( + "pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)] + ) + + @expose_as_static + def round(self) -> "Expression": + """Creates an expression that rounds this expression to the nearest integer. + + Example: + >>> # Round the 'value' field. + >>> Field.of("value").round() + + Returns: + A new `Expression` representing the rounded value. + """ + return FunctionExpression("round", [self]) + + @expose_as_static + def sqrt(self) -> "Expression": + """Creates an expression that calculates the square root of this expression. + + Example: + >>> # Get the square root of the 'area' field. + >>> Field.of("area").sqrt() + + Returns: + A new `Expression` representing the square root. + """ + return FunctionExpression("sqrt", [self]) + + @expose_as_static + def logical_maximum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": + """Creates an expression that returns the larger value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the larger value between the 'discount' field and the 'cap' field. + >>> Field.of("discount").logical_maximum(Field.of("cap")) + >>> # Returns the larger value between the 'value' field and some ints + >>> Field.of("value").logical_maximum(10, 20, 30) + + Args: + others: The other expression or constant values to compare with. + + Returns: + A new `Expression` representing the logical maximum operation. + """ + return FunctionExpression( + "maximum", + [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], + infix_name_override="logical_maximum", + ) + + @expose_as_static + def logical_minimum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": + """Creates an expression that returns the smaller value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the smaller value between the 'discount' field and the 'floor' field. + >>> Field.of("discount").logical_minimum(Field.of("floor")) + >>> # Returns the smaller value between the 'value' field and some ints + >>> Field.of("value").logical_minimum(10, 20, 30) + + Args: + others: The other expression or constant values to compare with. 
+ + Returns: + A new `Expression` representing the logical minimum operation. + """ + return FunctionExpression( + "minimum", + [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], + infix_name_override="logical_minimum", + ) + + @expose_as_static + def equal(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": + """Creates an expression that checks if this expression is equal to another + expression or constant value. + + Example: + >>> # Check if the 'age' field is equal to 21 + >>> Field.of("age").equal(21) + >>> # Check if the 'city' field is equal to "London" + >>> Field.of("city").equal("London") + + Args: + other: The expression or constant value to compare for equality. + + Returns: + A new `Expression` representing the equality comparison. + """ + return BooleanExpression( + "equal", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def not_equal(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": + """Creates an expression that checks if this expression is not equal to another + expression or constant value. + + Example: + >>> # Check if the 'status' field is not equal to "completed" + >>> Field.of("status").not_equal("completed") + >>> # Check if the 'country' field is not equal to "USA" + >>> Field.of("country").not_equal("USA") + + Args: + other: The expression or constant value to compare for inequality. + + Returns: + A new `Expression` representing the inequality comparison. + """ + return BooleanExpression( + "not_equal", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def greater_than(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": + """Creates an expression that checks if this expression is greater than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is greater than the 'limit' field + >>> Field.of("age").greater_than(Field.of("limit")) + >>> # Check if the 'price' field is greater than 100 + >>> Field.of("price").greater_than(100) + + Args: + other: The expression or constant value to compare for greater than. + + Returns: + A new `Expression` representing the greater than comparison. + """ + return BooleanExpression( + "greater_than", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def greater_than_or_equal( + self, other: Expression | CONSTANT_TYPE + ) -> "BooleanExpression": + """Creates an expression that checks if this expression is greater than or equal + to another expression or constant value. + + Example: + >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 + >>> Field.of("quantity").greater_than_or_equal(Field.of('requirement').add(1)) + >>> # Check if the 'score' field is greater than or equal to 80 + >>> Field.of("score").greater_than_or_equal(80) + + Args: + other: The expression or constant value to compare for greater than or equal to. + + Returns: + A new `Expression` representing the greater than or equal to comparison. + """ + return BooleanExpression( + "greater_than_or_equal", + [self, self._cast_to_expr_or_convert_to_constant(other)], + ) + + @expose_as_static + def less_than(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": + """Creates an expression that checks if this expression is less than another + expression or constant value. 
+ + Example: + >>> # Check if the 'age' field is less than 'limit' + >>> Field.of("age").less_than(Field.of('limit')) + >>> # Check if the 'price' field is less than 50 + >>> Field.of("price").less_than(50) + + Args: + other: The expression or constant value to compare for less than. + + Returns: + A new `Expression` representing the less than comparison. + """ + return BooleanExpression( + "less_than", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + + @expose_as_static + def less_than_or_equal( + self, other: Expression | CONSTANT_TYPE + ) -> "BooleanExpression": + """Creates an expression that checks if this expression is less than or equal to + another expression or constant value. + + Example: + >>> # Check if the 'quantity' field is less than or equal to 20 + >>> Field.of("quantity").less_than_or_equal(Constant.of(20)) + >>> # Check if the 'score' field is less than or equal to 70 + >>> Field.of("score").less_than_or_equal(70) + + Args: + other: The expression or constant value to compare for less than or equal to. + + Returns: + A new `Expression` representing the less than or equal to comparison. + """ + return BooleanExpression( + "less_than_or_equal", + [self, self._cast_to_expr_or_convert_to_constant(other)], + ) + + @expose_as_static + def equal_any( + self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression + ) -> "BooleanExpression": + """Creates an expression that checks if this expression is equal to any of the + provided values or expressions. + + Example: + >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' + >>> Field.of("category").equal_any(["Electronics", Field.of("primaryType")]) + + Args: + array: The values or expressions to check against. + + Returns: + A new `Expression` representing the 'IN' comparison. + """ + return BooleanExpression( + "equal_any", + [ + self, + self._cast_to_expr_or_convert_to_constant(array), + ], + ) + + @expose_as_static + def not_equal_any( + self, array: Array | list[Expression | CONSTANT_TYPE] | Expression + ) -> "BooleanExpression": + """Creates an expression that checks if this expression is not equal to any of the + provided values or expressions. + + Example: + >>> # Check if the 'status' field is neither "pending" nor "cancelled" + >>> Field.of("status").not_equal_any(["pending", "cancelled"]) + + Args: + array: The values or expressions to check against. + + Returns: + A new `Expression` representing the 'NOT IN' comparison. + """ + return BooleanExpression( + "not_equal_any", + [ + self, + self._cast_to_expr_or_convert_to_constant(array), + ], + ) + + @expose_as_static + def array_get(self, offset: Expression | int) -> "FunctionExpression": + """ + Creates an expression that indexes into an array from the beginning or end and returns the + element. A negative offset starts from the end. + + Example: + >>> Array([1,2,3]).array_get(0) + + Args: + offset: the index of the element to return + + Returns: + A new `Expression` representing the `array_get` operation. + """ + return FunctionExpression( + "array_get", [self, self._cast_to_expr_or_convert_to_constant(offset)] + ) + + @expose_as_static + def array_contains( + self, element: Expression | CONSTANT_TYPE + ) -> "BooleanExpression": + """Creates an expression that checks if an array contains a specific element or value. 
+ + Example: + >>> # Check if the 'sizes' array contains the value from the 'selectedSize' field + >>> Field.of("sizes").array_contains(Field.of("selectedSize")) + >>> # Check if the 'colors' array contains "red" + >>> Field.of("colors").array_contains("red") + + Args: + element: The element (expression or constant) to search for in the array. + + Returns: + A new `Expression` representing the 'array_contains' comparison. + """ + return BooleanExpression( + "array_contains", [self, self._cast_to_expr_or_convert_to_constant(element)] + ) + + @expose_as_static + def array_contains_all( + self, + elements: Array | list[Expression | CONSTANT_TYPE] | Expression, + ) -> "BooleanExpression": + """Creates an expression that checks if an array contains all the specified elements. + + Example: + >>> # Check if the 'tags' array contains both "news" and "sports" + >>> Field.of("tags").array_contains_all(["news", "sports"]) + >>> # Check if the 'tags' array contains both of the values from field 'tag1' and "tag2" + >>> Field.of("tags").array_contains_all([Field.of("tag1"), "tag2"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expression` representing the 'array_contains_all' comparison. + """ + return BooleanExpression( + "array_contains_all", + [ + self, + self._cast_to_expr_or_convert_to_constant(elements), + ], + ) + + @expose_as_static + def array_contains_any( + self, + elements: Array | list[Expression | CONSTANT_TYPE] | Expression, + ) -> "BooleanExpression": + """Creates an expression that checks if an array contains any of the specified elements. + + Example: + >>> # Check if the 'categories' array contains either values from field "cate1" or "cate2" + >>> Field.of("categories").array_contains_any([Field.of("cate1"), Field.of("cate2")]) + >>> # Check if the 'groups' array contains either the value from the 'userGroup' field + >>> # or the value "guest" + >>> Field.of("groups").array_contains_any([Field.of("userGroup"), "guest"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expression` representing the 'array_contains_any' comparison. + """ + return BooleanExpression( + "array_contains_any", + [ + self, + self._cast_to_expr_or_convert_to_constant(elements), + ], + ) + + @expose_as_static + def array_length(self) -> "Expression": + """Creates an expression that calculates the length of an array. + + Example: + >>> # Get the number of items in the 'cart' array + >>> Field.of("cart").array_length() + + Returns: + A new `Expression` representing the length of the array. + """ + return FunctionExpression("array_length", [self]) + + @expose_as_static + def array_reverse(self) -> "Expression": + """Creates an expression that returns the reversed content of an array. + + Example: + >>> # Get the 'preferences' array in reversed order. + >>> Field.of("preferences").array_reverse() + + Returns: + A new `Expression` representing the reversed array. + """ + return FunctionExpression("array_reverse", [self]) + + @expose_as_static + def array_concat( + self, *other_arrays: Array | list[Expression | CONSTANT_TYPE] | Expression + ) -> "Expression": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. 
+
+        Returns:
+            A new `Expression` representing the concatenated array.
+        """
+        return FunctionExpression(
+            "array_concat",
+            [self]
+            + [self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays],
+        )
+
+    @expose_as_static
+    def concat(self, *others: Expression | CONSTANT_TYPE) -> "Expression":
+        """Creates an expression that concatenates expressions together.
+
+        Args:
+            *others: The expressions to concatenate.
+
+        Returns:
+            A new `Expression` representing the concatenated value.
+        """
+        return FunctionExpression(
+            "concat",
+            [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others],
+        )
+
+    @expose_as_static
+    def length(self) -> "Expression":
+        """
+        Creates an expression that calculates the length of the expression if it is a string, array, map, or blob.
+
+        Example:
+            >>> # Get the length of the 'name' field.
+            >>> Field.of("name").length()
+
+        Returns:
+            A new `Expression` representing the length of the expression.
+        """
+        return FunctionExpression("length", [self])
+
+    @expose_as_static
+    def is_absent(self) -> "BooleanExpression":
+        """Creates an expression that returns true if a value is absent. Otherwise, returns false even if
+        the value is null.
+
+        Example:
+            >>> # Check if the 'email' field is absent.
+            >>> Field.of("email").is_absent()
+
+        Returns:
+            A new `BooleanExpression` representing the isAbsent operation.
+        """
+        return BooleanExpression("is_absent", [self])
+
+    @expose_as_static
+    def if_absent(self, default_value: Expression | CONSTANT_TYPE) -> "Expression":
+        """Creates an expression that returns a default value if an expression evaluates to an absent value.
+
+        Example:
+            >>> # Return the value of the 'email' field, or "N/A" if it's absent.
+            >>> Field.of("email").if_absent("N/A")
+
+        Args:
+            default_value: The expression or constant value to return if this expression is absent.
+
+        Returns:
+            A new `Expression` representing the ifAbsent operation.
+        """
+        return FunctionExpression(
+            "if_absent",
+            [self, self._cast_to_expr_or_convert_to_constant(default_value)],
+        )
+
+    @expose_as_static
+    def is_error(self):
+        """Creates an expression that checks if a given expression produces an error.
+
+        Example:
+            >>> # Resolves to True if an expression produces an error
+            >>> Field.of("value").divide("string").is_error()
+
+        Returns:
+            A new `Expression` representing the isError operation.
+        """
+        return FunctionExpression("is_error", [self])
+
+    @expose_as_static
+    def if_error(self, then_value: Expression | CONSTANT_TYPE) -> "Expression":
+        """Creates an expression that returns ``then_value`` if this expression evaluates to an error.
+        Otherwise, returns the value of this expression.
+
+        Example:
+            >>> # Resolves to 0 if an expression produces an error
+            >>> Field.of("value").divide("string").if_error(0)
+
+        Args:
+            then_value: The value to return if this expression evaluates to an error.
+
+        Returns:
+            A new `Expression` representing the ifError operation.
+        """
+        return FunctionExpression(
+            "if_error", [self, self._cast_to_expr_or_convert_to_constant(then_value)]
+        )
+
+    @expose_as_static
+    def exists(self) -> "BooleanExpression":
+        """Creates an expression that checks if a field exists in the document.
+
+        Example:
+            >>> # Check if the document has a field named "phoneNumber"
+            >>> Field.of("phoneNumber").exists()
+
+        Returns:
+            A new `Expression` representing the 'exists' check.
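+
+        An illustrative pipeline usage (assumes a ``where`` stage and a
+        hypothetical "users" collection):
+
+        >>> client.pipeline().collection("users").where(Field.of("phoneNumber").exists())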
+ """ + return BooleanExpression("exists", [self]) + + @expose_as_static + def sum(self) -> "Expression": + """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. + + Example: + >>> # Calculate the total revenue from a set of orders + >>> Field.of("orderAmount").sum().as_("totalRevenue") + + Returns: + A new `AggregateFunction` representing the 'sum' aggregation. + """ + return AggregateFunction("sum", [self]) + + @expose_as_static + def average(self) -> "Expression": + """Creates an aggregation that calculates the average (mean) of a numeric field across multiple + stage inputs. + + Example: + >>> # Calculate the average age of users + >>> Field.of("age").average().as_("averageAge") + + Returns: + A new `AggregateFunction` representing the 'avg' aggregation. + """ + return AggregateFunction("average", [self]) + + @expose_as_static + def count(self) -> "Expression": + """Creates an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. + + Example: + >>> # Count the total number of products + >>> Field.of("productId").count().as_("totalProducts") + + Returns: + A new `AggregateFunction` representing the 'count' aggregation. + """ + return AggregateFunction("count", [self]) + + @expose_as_static + def count_if(self) -> "Expression": + """Creates an aggregation that counts the number of values of the provided field or expression + that evaluate to True. + + Example: + >>> # Count the number of adults + >>> Field.of("age").greater_than(18).count_if().as_("totalAdults") + + + Returns: + A new `AggregateFunction` representing the 'count_if' aggregation. + """ + return AggregateFunction("count_if", [self]) + + @expose_as_static + def count_distinct(self) -> "Expression": + """Creates an aggregation that counts the number of distinct values of the + provided field or expression. + + Example: + >>> # Count the total number of countries in the data + >>> Field.of("country").count_distinct().as_("totalCountries") + + Returns: + A new `AggregateFunction` representing the 'count_distinct' aggregation. + """ + return AggregateFunction("count_distinct", [self]) + + @expose_as_static + def minimum(self) -> "Expression": + """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. + + Example: + >>> # Find the lowest price of all products + >>> Field.of("price").minimum().as_("lowestPrice") + + Returns: + A new `AggregateFunction` representing the 'minimum' aggregation. + """ + return AggregateFunction("minimum", [self]) + + @expose_as_static + def maximum(self) -> "Expression": + """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. + + Example: + >>> # Find the highest score in a leaderboard + >>> Field.of("score").maximum().as_("highestScore") + + Returns: + A new `AggregateFunction` representing the 'maximum' aggregation. + """ + return AggregateFunction("maximum", [self]) + + @expose_as_static + def char_length(self) -> "Expression": + """Creates an expression that calculates the character length of a string. + + Example: + >>> # Get the character length of the 'name' field + >>> Field.of("name").char_length() + + Returns: + A new `Expression` representing the length of the string. + """ + return FunctionExpression("char_length", [self]) + + @expose_as_static + def byte_length(self) -> "Expression": + """Creates an expression that calculates the byte length of a string in its UTF-8 form. 
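+
+        For strings containing only ASCII characters the byte length equals the
+        character length; multi-byte UTF-8 characters (such as accented letters
+        or emoji) increase the byte count.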
+ + Example: + >>> # Get the byte length of the 'name' field + >>> Field.of("name").byte_length() + + Returns: + A new `Expression` representing the byte length of the string. + """ + return FunctionExpression("byte_length", [self]) + + @expose_as_static + def like(self, pattern: Expression | str) -> "BooleanExpression": + """Creates an expression that performs a case-sensitive string comparison. + + Example: + >>> # Check if the 'title' field contains the word "guide" (case-sensitive) + >>> Field.of("title").like("%guide%") + >>> # Check if the 'title' field matches the pattern specified in field 'pattern'. + >>> Field.of("title").like(Field.of("pattern")) + + Args: + pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. + + Returns: + A new `Expression` representing the 'like' comparison. + """ + return BooleanExpression( + "like", [self, self._cast_to_expr_or_convert_to_constant(pattern)] + ) + + @expose_as_static + def regex_contains(self, regex: Expression | str) -> "BooleanExpression": + """Creates an expression that checks if a string contains a specified regular expression as a + substring. + + Example: + >>> # Check if the 'description' field contains "example" (case-insensitive) + >>> Field.of("description").regex_contains("(?i)example") + >>> # Check if the 'description' field contains the regular expression stored in field 'regex' + >>> Field.of("description").regex_contains(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the search. + + Returns: + A new `Expression` representing the 'contains' comparison. + """ + return BooleanExpression( + "regex_contains", [self, self._cast_to_expr_or_convert_to_constant(regex)] + ) + + @expose_as_static + def regex_match(self, regex: Expression | str) -> "BooleanExpression": + """Creates an expression that checks if a string matches a specified regular expression. + + Example: + >>> # Check if the 'email' field matches a valid email pattern + >>> Field.of("email").regex_match("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> # Check if the 'email' field matches a regular expression stored in field 'regex' + >>> Field.of("email").regex_match(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the match. + + Returns: + A new `Expression` representing the regular expression match. + """ + return BooleanExpression( + "regex_match", [self, self._cast_to_expr_or_convert_to_constant(regex)] + ) + + @expose_as_static + def string_contains(self, substring: Expression | str) -> "BooleanExpression": + """Creates an expression that checks if this string expression contains a specified substring. + + Example: + >>> # Check if the 'description' field contains "example". + >>> Field.of("description").string_contains("example") + >>> # Check if the 'description' field contains the value of the 'keyword' field. + >>> Field.of("description").string_contains(Field.of("keyword")) + + Args: + substring: The substring (string or expression) to use for the search. + + Returns: + A new `Expression` representing the 'contains' comparison. + """ + return BooleanExpression( + "string_contains", + [self, self._cast_to_expr_or_convert_to_constant(substring)], + ) + + @expose_as_static + def starts_with(self, prefix: Expression | str) -> "BooleanExpression": + """Creates an expression that checks if a string starts with a given prefix. + + Example: + >>> # Check if the 'name' field starts with "Mr." 
+ >>> Field.of("name").starts_with("Mr.") + >>> # Check if the 'fullName' field starts with the value of the 'firstName' field + >>> Field.of("fullName").starts_with(Field.of("firstName")) + + Args: + prefix: The prefix (string or expression) to check for. + + Returns: + A new `Expression` representing the 'starts with' comparison. + """ + return BooleanExpression( + "starts_with", [self, self._cast_to_expr_or_convert_to_constant(prefix)] + ) + + @expose_as_static + def ends_with(self, postfix: Expression | str) -> "BooleanExpression": + """Creates an expression that checks if a string ends with a given postfix. + + Example: + >>> # Check if the 'filename' field ends with ".txt" + >>> Field.of("filename").ends_with(".txt") + >>> # Check if the 'url' field ends with the value of the 'extension' field + >>> Field.of("url").ends_with(Field.of("extension")) + + Args: + postfix: The postfix (string or expression) to check for. + + Returns: + A new `Expression` representing the 'ends with' comparison. + """ + return BooleanExpression( + "ends_with", [self, self._cast_to_expr_or_convert_to_constant(postfix)] + ) + + @expose_as_static + def string_concat(self, *elements: Expression | CONSTANT_TYPE) -> "Expression": + """Creates an expression that concatenates string expressions, fields or constants together. + + Example: + >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string + >>> Field.of("firstName").string_concat(" ", Field.of("lastName")) + + Args: + *elements: The expressions or constants (typically strings) to concatenate. + + Returns: + A new `Expression` representing the concatenated string. + """ + return FunctionExpression( + "string_concat", + [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], + ) + + @expose_as_static + def to_lower(self) -> "Expression": + """Creates an expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expression` representing the lowercase string. + """ + return FunctionExpression("to_lower", [self]) + + @expose_as_static + def to_upper(self) -> "Expression": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expression` representing the uppercase string. + """ + return FunctionExpression("to_upper", [self]) + + @expose_as_static + def trim(self) -> "Expression": + """Creates an expression that removes leading and trailing whitespace from a string. + + Example: + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expression` representing the trimmed string. + """ + return FunctionExpression("trim", [self]) + + @expose_as_static + def string_reverse(self) -> "Expression": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").reverse() + + Returns: + A new `Expression` representing the reversed string. + """ + return FunctionExpression("string_reverse", [self]) + + @expose_as_static + def substring( + self, position: Expression | int, length: Expression | int | None = None + ) -> "Expression": + """Creates an expression that returns a substring of the results of this expression. 
+
+        Example:
+            >>> Field.of("description").substring(5, 10)
+            >>> Field.of("description").substring(5)
+
+        Args:
+            position: the index of the first character of the substring.
+            length: the length of the substring. If not provided the substring
+                will end at the end of the input.
+
+        Returns:
+            A new `Expression` representing the extracted substring.
+        """
+        args = [self, self._cast_to_expr_or_convert_to_constant(position)]
+        if length is not None:
+            args.append(self._cast_to_expr_or_convert_to_constant(length))
+        return FunctionExpression("substring", args)
+
+    @expose_as_static
+    def join(self, delimiter: Expression | str) -> "Expression":
+        """Creates an expression that joins the elements of an array into a string.
+
+        Example:
+            >>> Field.of("tags").join(", ")
+
+        Args:
+            delimiter: The delimiter to add between the elements of the array.
+
+        Returns:
+            A new `Expression` representing the joined string.
+        """
+        return FunctionExpression(
+            "join", [self, self._cast_to_expr_or_convert_to_constant(delimiter)]
+        )
+
+    @expose_as_static
+    def map_get(self, key: str | Constant[str]) -> "Expression":
+        """Accesses a value from the map produced by evaluating this expression.
+
+        Example:
+            >>> Map({"city": "London"}).map_get("city")
+            >>> Field.of("address").map_get("city")
+
+        Args:
+            key: The key to access in the map.
+
+        Returns:
+            A new `Expression` representing the value associated with the given key in the map.
+        """
+        return FunctionExpression(
+            "map_get", [self, self._cast_to_expr_or_convert_to_constant(key)]
+        )
+
+    @expose_as_static
+    def map_remove(self, key: str | Constant[str]) -> "Expression":
+        """Removes a key from the map produced by evaluating this expression.
+
+        Example:
+            >>> Map({"city": "London"}).map_remove("city")
+            >>> Field.of("address").map_remove("city")
+
+        Args:
+            key: The key to remove from the map.
+
+        Returns:
+            A new `Expression` representing the map_remove operation.
+        """
+        return FunctionExpression(
+            "map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)]
+        )
+
+    @expose_as_static
+    def map_merge(
+        self,
+        *other_maps: Map
+        | dict[str | Constant[str], Expression | CONSTANT_TYPE]
+        | Expression,
+    ) -> "Expression":
+        """Creates an expression that merges one or more dicts into a single map.
+
+        Example:
+            >>> Map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True})
+            >>> Field.of("settings").map_merge({"enabled": True}, Conditional(Field.of("isAdmin"), Map({"admin": True}), Map({})))
+
+        Args:
+            *other_maps: Sequence of maps to merge into the resulting map.
+
+        Returns:
+            A new `Expression` representing the merged map.
+        """
+        return FunctionExpression(
+            "map_merge",
+            [self] + [self._cast_to_expr_or_convert_to_constant(m) for m in other_maps],
+        )
+
+    @expose_as_static
+    def cosine_distance(self, other: Expression | list[float] | Vector) -> "Expression":
+        """Calculates the cosine distance between two vectors.
+
+        Example:
+            >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field
+            >>> Field.of("userVector").cosine_distance(Field.of("itemVector"))
+            >>> # Calculate the cosine distance between the 'location' field and a target location
+            >>> Field.of("location").cosine_distance([37.7749, -122.4194])
+
+        Args:
+            other: The other vector (represented as an Expression, list of floats, or Vector) to compare against.
+
+        Returns:
+            A new `Expression` representing the cosine distance between the two vectors.
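+
+        An illustrative comparison against a Vector constant (hypothetical
+        embedding values):
+
+        >>> from google.cloud.firestore_v1.vector import Vector
+        >>> Field.of("embedding").cosine_distance(Vector([0.1, 0.2, 0.3]))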
+ """ + return FunctionExpression( + "cosine_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def euclidean_distance( + self, other: Expression | list[float] | Vector + ) -> "Expression": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expression, list of floats, or Vector) to compare against. + + Returns: + A new `Expression` representing the Euclidean distance between the two vectors. + """ + return FunctionExpression( + "euclidean_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def dot_product(self, other: Expression | list[float] | Vector) -> "Expression": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expression, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expression` representing the dot product between the two vectors. + """ + return FunctionExpression( + "dot_product", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def vector_length(self) -> "Expression": + """Creates an expression that calculates the length (dimension) of a Firestore Vector. + + Example: + >>> # Get the vector length (dimension) of the field 'embedding'. + >>> Field.of("embedding").vector_length() + + Returns: + A new `Expression` representing the length of the vector. + """ + return FunctionExpression("vector_length", [self]) + + @expose_as_static + def timestamp_to_unix_micros(self) -> "Expression": + """Creates an expression that converts a timestamp to the number of microseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the microsecond. + + Example: + >>> # Convert the 'timestamp' field to microseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_micros() + + Returns: + A new `Expression` representing the number of microseconds since the epoch. + """ + return FunctionExpression("timestamp_to_unix_micros", [self]) + + @expose_as_static + def unix_micros_to_timestamp(self) -> "Expression": + """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'microseconds' field to a timestamp. + >>> Field.of("microseconds").unix_micros_to_timestamp() + + Returns: + A new `Expression` representing the timestamp. + """ + return FunctionExpression("unix_micros_to_timestamp", [self]) + + @expose_as_static + def timestamp_to_unix_millis(self) -> "Expression": + """Creates an expression that converts a timestamp to the number of milliseconds since the epoch + (1970-01-01 00:00:00 UTC). 
+ + Truncates higher levels of precision by rounding down to the beginning of the millisecond. + + Example: + >>> # Convert the 'timestamp' field to milliseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_millis() + + Returns: + A new `Expression` representing the number of milliseconds since the epoch. + """ + return FunctionExpression("timestamp_to_unix_millis", [self]) + + @expose_as_static + def unix_millis_to_timestamp(self) -> "Expression": + """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'milliseconds' field to a timestamp. + >>> Field.of("milliseconds").unix_millis_to_timestamp() + + Returns: + A new `Expression` representing the timestamp. + """ + return FunctionExpression("unix_millis_to_timestamp", [self]) + + @expose_as_static + def timestamp_to_unix_seconds(self) -> "Expression": + """Creates an expression that converts a timestamp to the number of seconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the second. + + Example: + >>> # Convert the 'timestamp' field to seconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_seconds() + + Returns: + A new `Expression` representing the number of seconds since the epoch. + """ + return FunctionExpression("timestamp_to_unix_seconds", [self]) + + @expose_as_static + def unix_seconds_to_timestamp(self) -> "Expression": + """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 + UTC) to a timestamp. + + Example: + >>> # Convert the 'seconds' field to a timestamp. + >>> Field.of("seconds").unix_seconds_to_timestamp() + + Returns: + A new `Expression` representing the timestamp. + """ + return FunctionExpression("unix_seconds_to_timestamp", [self]) + + @expose_as_static + def timestamp_add( + self, unit: Expression | str, amount: Expression | float + ) -> "Expression": + """Creates an expression that adds a specified amount of time to this timestamp expression. + + Example: + >>> # Add a duration specified by the 'unit' and 'amount' fields to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add(Field.of("unit"), Field.of("amount")) + >>> # Add 1.5 days to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add("day", 1.5) + + Args: + unit: The expression or string evaluating to the unit of time to add, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to add. + + Returns: + A new `Expression` representing the resulting timestamp. + """ + return FunctionExpression( + "timestamp_add", + [ + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ], + ) + + @expose_as_static + def timestamp_subtract( + self, unit: Expression | str, amount: Expression | float + ) -> "Expression": + """Creates an expression that subtracts a specified amount of time from this timestamp expression. + + Example: + >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. + >>> Field.of("timestamp").timestamp_subtract(Field.of("unit"), Field.of("amount")) + >>> # Subtract 2.5 hours from the 'timestamp' field. 
+ >>> Field.of("timestamp").timestamp_subtract("hour", 2.5) + + Args: + unit: The expression or string evaluating to the unit of time to subtract, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to subtract. + + Returns: + A new `Expression` representing the resulting timestamp. + """ + return FunctionExpression( + "timestamp_subtract", + [ + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ], + ) + + @expose_as_static + def collection_id(self): + """Creates an expression that returns the collection ID from a path. + + Example: + >>> # Get the collection ID from a path. + >>> Field.of("__name__").collection_id() + + Returns: + A new `Expression` representing the collection ID. + """ + return FunctionExpression("collection_id", [self]) + + @expose_as_static + def document_id(self): + """Creates an expression that returns the document ID from a path. + + Example: + >>> # Get the document ID from a path. + >>> Field.of("__name__").document_id() + + Returns: + A new `Expression` representing the document ID. + """ + return FunctionExpression("document_id", [self]) + + def ascending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in ascending order based on this expression. + + Example: + >>> # Sort documents by the 'name' field in ascending order + >>> client.pipeline().collection("users").sort(Field.of("name").ascending()) + + Returns: + A new `Ordering` for ascending sorting. + """ + return Ordering(self, Ordering.Direction.ASCENDING) + + def descending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in descending order based on this expression. + + Example: + >>> # Sort documents by the 'createdAt' field in descending order + >>> client.pipeline().collection("users").sort(Field.of("createdAt").descending()) + + Returns: + A new `Ordering` for descending sorting. + """ + return Ordering(self, Ordering.Direction.DESCENDING) + + def as_(self, alias: str) -> "AliasedExpression": + """Assigns an alias to this expression. + + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. + + Example: + >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. + >>> client.pipeline().collection("items").add_fields( + ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") + ... ) + + Args: + alias: The alias to assign to this expression. + + Returns: + A new `Selectable` (typically an `AliasedExpression`) that wraps this + expression and associates it with the provided alias. 
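+
+        An illustrative select usage (a sketch; assumes the ``select`` stage
+        accepts aliased expressions):
+
+        >>> client.pipeline().collection("items").select(
+        ...     Field.of("price").multiply(Field.of("quantity")).as_("totalPrice")
+        ... )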
+ """ + return AliasedExpression(self, alias) + + +class Constant(Expression, Generic[CONSTANT_TYPE]): + """Represents a constant literal value in an expression.""" + + def __init__(self, value: CONSTANT_TYPE): + self.value: CONSTANT_TYPE = value + + def __eq__(self, other): + if not isinstance(other, Constant): + return other == self.value + else: + return other.value == self.value + + @staticmethod + def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: + """Creates a constant expression from a Python value.""" + return Constant(value) + + def __repr__(self): + value_str = repr(self.value) + if isinstance(self.value, float) and value_str == "nan": + value_str = "math.nan" + return f"Constant.of({value_str})" + + def __hash__(self): + return hash(self.value) + + def _to_pb(self) -> Value: + return encode_value(self.value) + + +class FunctionExpression(Expression): + """A base class for expressions that represent function calls.""" + + def __init__( + self, + name: str, + params: Sequence[Expression], + *, + use_infix_repr: bool = True, + infix_name_override: str | None = None, + ): + self.name = name + self.params = list(params) + self._use_infix_repr = use_infix_repr + self._infix_name_override = infix_name_override + + def __repr__(self): + """ + Most FunctionExpressions can be triggered infix. Eg: Field.of('age').greater_than(18). + + Display them this way in the repr string where possible + """ + if self._use_infix_repr: + infix_name = self._infix_name_override or self.name + if len(self.params) == 1: + return f"{self.params[0]!r}.{infix_name}()" + elif len(self.params) == 2: + return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" + else: + return f"{self.params[0]!r}.{infix_name}({', '.join([repr(p) for p in self.params[1:]])})" + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" + + def __eq__(self, other): + if not isinstance(other, FunctionExpression): + return False + else: + return other.name == self.name and other.params == self.params + + def _to_pb(self): + return Value( + function_value={ + "name": self.name, + "args": [p._to_pb() for p in self.params], + } + ) + + +class AggregateFunction(FunctionExpression): + """A base class for aggregation functions that operate across multiple inputs.""" + + +class Selectable(Expression): + """Base class for expressions that can be selected or aliased in projection stages.""" + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + else: + return other._to_map() == self._to_map() + + @abstractmethod + def _to_map(self) -> tuple[str, Value]: + """ + Returns a str: Value representation of the Selectable + """ + raise NotImplementedError + + @classmethod + def _value_from_selectables(cls, *selectables: Selectable) -> Value: + """ + Returns a Value representing a map of Selectables + """ + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} + } + ) + + @staticmethod + def _to_value(field_list: Sequence[Selectable]) -> Value: + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in field_list]} + } + ) + + +T = TypeVar("T", bound=Expression) + + +class AliasedExpression(Selectable, Generic[T]): + """Wraps an expression with an alias.""" + + def __init__(self, expr: T, alias: str): + self.expr = expr + self.alias = alias + + def _to_map(self): + return self.alias, self.expr._to_pb() + + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" + + def _to_pb(self): + return 
Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) + + +class Field(Selectable): + """Represents a reference to a field within a document.""" + + DOCUMENT_ID = "__name__" + + def __init__(self, path: str): + """Initializes a Field reference. + + Args: + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. + """ + self.path = path + + @staticmethod + def of(path: str): + """Creates a Field reference. + + Args: + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. + + Returns: + A new Field instance. + """ + return Field(path) + + def _to_map(self): + return self.path, self._to_pb() + + def __repr__(self): + return f"Field.of({self.path!r})" + + def _to_pb(self): + return Value(field_reference_value=self.path) + + +class BooleanExpression(FunctionExpression): + """Filters the given data in some way.""" + + @staticmethod + def _from_query_filter_pb(filter_pb, client): + if isinstance(filter_pb, Query_pb.CompositeFilter): + sub_filters = [ + BooleanExpression._from_query_filter_pb(f, client) + for f in filter_pb.filters + ] + if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: + return Or(*sub_filters) + elif filter_pb.op == Query_pb.CompositeFilter.Operator.AND: + return And(*sub_filters) + else: + raise TypeError( + f"Unexpected CompositeFilter operator type: {filter_pb.op}" + ) + elif isinstance(filter_pb, Query_pb.UnaryFilter): + field = Field.of(filter_pb.field.field_path) + if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: + return And(field.exists(), field.equal(float("nan"))) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: + return And(field.exists(), Not(field.equal(float("nan")))) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: + return And(field.exists(), field.equal(None)) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: + return And(field.exists(), Not(field.equal(None))) + else: + raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") + elif isinstance(filter_pb, Query_pb.FieldFilter): + field = Field.of(filter_pb.field.field_path) + value = decode_value(filter_pb.value, client) + if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: + return And(field.exists(), field.less_than(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: + return And(field.exists(), field.less_than_or_equal(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: + return And(field.exists(), field.greater_than(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: + return And(field.exists(), field.greater_than_or_equal(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: + return And(field.exists(), field.equal(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: + return And(field.exists(), field.not_equal(value)) + if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: + return And(field.exists(), field.array_contains(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: + return And(field.exists(), field.array_contains_any(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: + return And(field.exists(), field.equal_any(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: + return And(field.exists(), field.not_equal_any(value)) + else: + raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") + elif 
isinstance(filter_pb, Query_pb.Filter): + # unwrap oneof + f = ( + filter_pb.composite_filter + or filter_pb.field_filter + or filter_pb.unary_filter + ) + return BooleanExpression._from_query_filter_pb(f, client) + else: + raise TypeError(f"Unexpected filter type: {type(filter_pb)}") + + +class Array(FunctionExpression): + """ + Creates an expression that creates a Firestore array value from an input list. + + Example: + >>> Array(["bar", Field.of("baz")]) + + Args: + elements: The input list to evaluate in the expression + """ + + def __init__(self, elements: list[Expression | CONSTANT_TYPE]): + if not isinstance(elements, list): + raise TypeError("Array must be constructed with a list") + converted_elements = [ + self._cast_to_expr_or_convert_to_constant(el) for el in elements + ] + super().__init__("array", converted_elements) + + def __repr__(self): + return f"Array({self.params})" + + +class Map(FunctionExpression): + """ + Creates an expression that creates a Firestore map value from an input dict. + + Example: + >>> Expression.map({"foo": "bar", "baz": Field.of("baz")}) + + Args: + elements: The input dict to evaluate in the expression + """ + + def __init__(self, elements: dict[str | Constant[str], Expression | CONSTANT_TYPE]): + element_list = [] + for k, v in elements.items(): + element_list.append(self._cast_to_expr_or_convert_to_constant(k)) + element_list.append(self._cast_to_expr_or_convert_to_constant(v)) + super().__init__("map", element_list) + + def __repr__(self): + formatted_params = [ + a.value if isinstance(a, Constant) else a for a in self.params + ] + d = {a: b for a, b in zip(formatted_params[::2], formatted_params[1::2])} + return f"Map({d})" + + +class And(BooleanExpression): + """ + Represents an expression that performs a logical 'AND' operation on multiple filter conditions. + + Example: + >>> # Check if the 'age' field is greater than 18 AND the 'city' field is "London" AND + >>> # the 'status' field is "active" + >>> And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + + Args: + *conditions: The filter conditions to 'AND' together. + """ + + def __init__(self, *conditions: "BooleanExpression"): + super().__init__("and", conditions, use_infix_repr=False) + + +class Not(BooleanExpression): + """ + Represents an expression that negates a filter condition. + + Example: + >>> # Find documents where the 'completed' field is NOT true + >>> Not(Field.of("completed").equal(True)) + + Args: + condition: The filter condition to negate. + """ + + def __init__(self, condition: BooleanExpression): + super().__init__("not", [condition], use_infix_repr=False) + + +class Or(BooleanExpression): + """ + Represents expression that performs a logical 'OR' operation on multiple filter conditions. + + Example: + >>> # Check if the 'age' field is greater than 18 OR the 'city' field is "London" OR + >>> # the 'status' field is "active" + >>> Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + + Args: + *conditions: The filter conditions to 'OR' together. + """ + + def __init__(self, *conditions: "BooleanExpression"): + super().__init__("or", conditions, use_infix_repr=False) + + +class Xor(BooleanExpression): + """ + Represents an expression that performs a logical 'XOR' (exclusive OR) operation on multiple filter conditions. + + Example: + >>> # Check if only one of the conditions is true: 'age' greater than 18, 'city' is "London", + >>> # or 'status' is "active". 
+        >>> Xor([Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")])
+
+    Args:
+        conditions: A sequence of filter conditions to 'XOR' together.
+    """
+
+    def __init__(self, conditions: Sequence["BooleanExpression"]):
+        super().__init__("xor", conditions, use_infix_repr=False)
+
+
+class Conditional(BooleanExpression):
+    """
+    Represents a conditional expression that evaluates to a 'then' expression if a condition is true
+    and an 'else' expression if the condition is false.
+
+    Example:
+        >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor".
+        >>> Conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor"))
+
+    Args:
+        condition: The condition to evaluate.
+        then_expr: The expression to return if the condition is true.
+        else_expr: The expression to return if the condition is false.
+    """
+
+    def __init__(
+        self, condition: BooleanExpression, then_expr: Expression, else_expr: Expression
+    ):
+        super().__init__(
+            "conditional", [condition, then_expr, else_expr], use_infix_repr=False
+        )
+
+
+class Count(AggregateFunction):
+    """
+    Represents an aggregation that counts the number of stage inputs with valid evaluations of the
+    expression or field.
+
+    Example:
+        >>> # Count the total number of products
+        >>> Field.of("productId").count().as_("totalProducts")
+        >>> Count(Field.of("productId"))
+        >>> Count().as_("count")
+
+    Args:
+        expression: The expression or field to count. If None, counts all stage inputs.
+    """
+
+    def __init__(self, expression: Expression | None = None):
+        expression_list = [expression] if expression else []
+        super().__init__("count", expression_list, use_infix_repr=bool(expression_list))
+
+
+class CurrentTimestamp(FunctionExpression):
+    """Creates an expression that returns the current timestamp.
+
+    Returns:
+        A new `Expression` representing the current timestamp.
+    """
+
+    def __init__(self):
+        super().__init__("current_timestamp", [], use_infix_repr=False)
diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py
new file mode 100644
index 000000000..0496d0bfc
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_result.py
@@ -0,0 +1,302 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+.. warning::
+    **Preview API**: Firestore Pipelines is currently in preview and is
+    subject to potential breaking changes in future releases.
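+
+Illustrative usage (a sketch; hypothetical collection and field names, and
+assumes the synchronous client's ``stream()`` yields ``PipelineResult`` objects):
+
+    >>> for result in client.pipeline().collection("books").stream():
+    ...     print(result.id, result.data())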
+""" + +from __future__ import annotations +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + List, + Generic, + MutableMapping, + Type, + TypeVar, + TYPE_CHECKING, +) +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.field_path import get_nested_value +from google.cloud.firestore_v1.field_path import FieldPath +from google.cloud.firestore_v1.query_profile import ExplainStats +from google.cloud.firestore_v1.query_profile import QueryExplainError +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.types.document import Value + +if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse + from google.cloud.firestore_v1.types.document import Value as ValueProto + from google.cloud.firestore_v1.vector import Vector + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.base_pipeline import _BasePipeline + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_expressions import Constant + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions + + +class PipelineResult: + """ + Contains data read from a Firestore Pipeline. The data can be extracted with + the `data()` or `get()` methods. + + If the PipelineResult represents a non-document result `ref` may be `None`. + """ + + def __init__( + self, + client: BaseClient, + fields_pb: MutableMapping[str, ValueProto], + ref: BaseDocumentReference | None = None, + execution_time: Timestamp | None = None, + create_time: Timestamp | None = None, + update_time: Timestamp | None = None, + ): + """ + PipelineResult should be returned from `pipeline.execute()`, not constructed manually. + + Args: + client: The Firestore client instance. + fields_pb: A map of field names to their protobuf Value representations. + ref: The DocumentReference or AsyncDocumentReference if this result corresponds to a document. + execution_time: The time at which the pipeline execution producing this result occurred. + create_time: The creation time of the document, if applicable. + update_time: The last update time of the document, if applicable. + """ + self._client = client + self._fields_pb = fields_pb + self._ref = ref + self._execution_time = execution_time + self._create_time = create_time + self._update_time = update_time + + def __repr__(self): + return f"{type(self).__name__}(data={self.data()})" + + @property + def ref(self) -> BaseDocumentReference | None: + """ + The `BaseDocumentReference` if this result represents a document, else `None`. + """ + return self._ref + + @property + def id(self) -> str | None: + """The ID of the document if this result represents a document, else `None`.""" + return self._ref.id if self._ref else None + + @property + def create_time(self) -> Timestamp | None: + """The creation time of the document. 
`None` if not applicable.""" + return self._create_time + + @property + def update_time(self) -> Timestamp | None: + """The last update time of the document. `None` if not applicable.""" + return self._update_time + + @property + def execution_time(self) -> Timestamp: + """ + The time at which the pipeline producing this result was executed. + + Raise: + ValueError: if not set + """ + if self._execution_time is None: + raise ValueError("'execution_time' is expected to exist, but it is None.") + return self._execution_time + + def __eq__(self, other: object) -> bool: + """ + Compares this `PipelineResult` to another object for equality. + + Two `PipelineResult` instances are considered equal if their document + references (if any) are equal and their underlying field data + (protobuf representation) is identical. + """ + if not isinstance(other, PipelineResult): + return NotImplemented + return (self._ref == other._ref) and (self._fields_pb == other._fields_pb) + + def data(self) -> dict | "Vector" | None: + """ + Retrieves all fields in the result. + + Returns: + The data in dictionary format, or `None` if the document doesn't exist. + """ + if self._fields_pb is None: + return None + + return _helpers.decode_dict(self._fields_pb, self._client) + + def get(self, field_path: str | FieldPath) -> Any: + """ + Retrieves the field specified by `field_path`. + + Args: + field_path: The field path (e.g. 'foo' or 'foo.bar') to a specific field. + + Returns: + The data at the specified field location, decoded to Python types. + """ + str_path = ( + field_path if isinstance(field_path, str) else field_path.to_api_repr() + ) + value = get_nested_value(str_path, self._fields_pb) + return _helpers.decode_value(value, self._client) + + +T = TypeVar("T", bound=PipelineResult) + + +class _PipelineResultContainer(Generic[T]): + """Base class to hold shared attributes for PipelineSnapshot and PipelineStream""" + + def __init__( + self, + return_type: Type[T], + pipeline: Pipeline | AsyncPipeline, + transaction: Transaction | AsyncTransaction | None, + read_time: datetime.datetime | None, + explain_options: PipelineExplainOptions | None, + additional_options: dict[str, Constant | Value], + ): + # public + self.transaction = transaction + self.pipeline: _BasePipeline = pipeline + self.execution_time: Timestamp | None = None + # private + self._client: Client | AsyncClient = pipeline._client + self._started: bool = False + self._read_time = read_time + self._explain_stats: ExplainStats | None = None + self._explain_options: PipelineExplainOptions | None = explain_options + self._return_type = return_type + self._additonal_options = { + k: v if isinstance(v, Value) else v._to_pb() + for k, v in additional_options.items() + } + + @property + def explain_stats(self) -> ExplainStats: + if self._explain_stats is not None: + return self._explain_stats + elif self._explain_options is None: + raise QueryExplainError("explain_options not set on query.") + elif not self._started: + raise QueryExplainError( + "explain_stats not available until query is complete" + ) + else: + raise QueryExplainError("explain_stats not found") + + def _build_request(self) -> ExecutePipelineRequest: + """ + shared logic for creating an ExecutePipelineRequest + """ + database_name = ( + f"projects/{self._client.project}/databases/{self._client._database}" + ) + transaction_id = ( + _helpers.get_transaction_id(self.transaction, read_operation=False) + if self.transaction is not None + else None + ) + options = {} + if self._explain_options: 
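+            # Explain options are serialized to a protobuf ``Value`` and merged into
+            # the structured pipeline's options map; caller-supplied additional
+            # options are merged in below and take precedence on conflicts.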
+ options["explain_options"] = self._explain_options._to_value() + if self._additonal_options: + options.update(self._additonal_options) + request = ExecutePipelineRequest( + database=database_name, + transaction=transaction_id, + structured_pipeline=self.pipeline._to_pb(**options), + read_time=self._read_time, + ) + return request + + def _process_response(self, response: ExecutePipelineResponse) -> Iterable[T]: + """Shared logic for processing an individual response from a stream""" + if response.explain_stats: + self._explain_stats = ExplainStats(response.explain_stats) + execution_time = response._pb.execution_time + if execution_time and not self.execution_time: + self.execution_time = execution_time + for doc in response.results: + ref = self._client.document(doc.name) if doc.name else None + yield self._return_type( + self._client, + doc.fields, + ref, + execution_time, + doc._pb.create_time if doc.create_time else None, + doc._pb.update_time if doc.update_time else None, + ) + + +class PipelineSnapshot(_PipelineResultContainer[T], List[T]): + """ + A list type that holds the result of a pipeline.execute() operation, along with related metadata + """ + + def __init__(self, results_list: List[T], source: _PipelineResultContainer[T]): + self.__dict__.update(source.__dict__.copy()) + list.__init__(self, results_list) + # snapshots are always complete + self._started = True + + +class PipelineStream(_PipelineResultContainer[T], Iterable[T]): + """ + An iterable stream representing the result of a pipeline.stream() operation, along with related metadata + """ + + def __iter__(self) -> Iterator[T]: + if self._started: + raise RuntimeError(f"{self.__class__.__name__} can only be iterated once") + self._started = True + request = self._build_request() + stream = self._client._firestore_api.execute_pipeline(request) + for response in stream: + yield from self._process_response(response) + + +class AsyncPipelineStream(_PipelineResultContainer[T], AsyncIterable[T]): + """ + An iterable stream representing the result of an async pipeline.stream() operation, along with related metadata + """ + + async def __aiter__(self) -> AsyncIterator[T]: + if self._started: + raise RuntimeError(f"{self.__class__.__name__} can only be iterated once") + self._started = True + request = self._build_request() + stream = await self._client._firestore_api.execute_pipeline(request) + async for response in stream: + for result in self._process_response(response): + yield result diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py new file mode 100644 index 000000000..8f3c0a626 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. 
+""" + +from __future__ import annotations +from typing import Generic, TypeVar, TYPE_CHECKING +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1._helpers import DOCUMENT_PATH_DELIMITER + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + +PipelineType = TypeVar("PipelineType", bound=_BasePipeline) + + +class PipelineSource(Generic[PipelineType]): + """ + A factory for creating Pipeline instances, which provide a framework for building data + transformation and query pipelines for Firestore. + + Not meant to be instantiated directly. Instead, start by calling client.pipeline() + to obtain an instance of PipelineSource. From there, you can use the provided + methods to specify the data source for your pipeline. + """ + + def __init__(self, client: Client | AsyncClient): + self.client = client + + def _create_pipeline(self, source_stage): + return self.client._pipeline_cls._create_with_stages(self.client, source_stage) + + def create_from( + self, query: "BaseQuery" | "BaseAggregationQuery" | "BaseCollectionReference" + ) -> PipelineType: + """ + Create a pipeline from an existing query + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Args: + query: the query to build the pipeline off of + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a new pipeline instance representing the query + """ + return query._build_pipeline(self) + + def collection(self, path: str | tuple[str]) -> PipelineType: + """ + Creates a new Pipeline that operates on a specified Firestore collection. + + Args: + path: The path to the Firestore collection (e.g., "users"). Can either be: + * A single ``/``-delimited path to a collection + * A tuple of collection path segment + Returns: + a new pipeline instance targeting the specified collection + """ + if isinstance(path, tuple): + path = DOCUMENT_PATH_DELIMITER.join(path) + return self._create_pipeline(stages.Collection(path)) + + def collection_group(self, collection_id: str) -> PipelineType: + """ + Creates a new Pipeline that that operates on all documents in a collection group. + Args: + collection_id: The ID of the collection group + Returns: + a new pipeline instance targeting the specified collection group + """ + return self._create_pipeline(stages.CollectionGroup(collection_id)) + + def database(self) -> PipelineType: + """ + Creates a new Pipeline that operates on all documents in the Firestore database. + Returns: + a new pipeline instance targeting the specified collection + """ + return self._create_pipeline(stages.Database()) + + def documents(self, *docs: "BaseDocumentReference") -> PipelineType: + """ + Creates a new Pipeline that operates on a specific set of Firestore documents. + Args: + docs: The DocumentReference instances representing the documents to include in the pipeline. 
+ Returns: + a new pipeline instance targeting the specified documents + """ + return self._create_pipeline(stages.Documents.of(*docs)) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py new file mode 100644 index 000000000..18aa27044 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -0,0 +1,475 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" + +from __future__ import annotations +from typing import Optional, Sequence, TYPE_CHECKING +from abc import ABC +from abc import abstractmethod +from enum import Enum + +from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.pipeline_expressions import ( + AggregateFunction, + Expression, + AliasedExpression, + Field, + BooleanExpression, + Selectable, + Ordering, +) +from google.cloud.firestore_v1._helpers import encode_value + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.base_pipeline import _BasePipeline + from google.cloud.firestore_v1.base_document import BaseDocumentReference + + +class FindNearestOptions: + """Options for configuring the `FindNearest` pipeline stage. + + Attributes: + limit (Optional[int]): The maximum number of nearest neighbors to return. + distance_field (Optional[Field]): An optional field to store the calculated + distance in the output documents. 
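+
+    Example (illustrative sketch; ``FindNearest`` is the stage defined later in
+    this module and ``Field`` comes from ``pipeline_expressions``):
+        >>> options = FindNearestOptions(limit=10, distance_field=Field.of("distance"))
+        >>> FindNearest("embedding", [0.1, 0.2, 0.3], "euclidean", options=options)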
+ """ + + def __init__( + self, + limit: Optional[int] = None, + distance_field: Optional[Field] = None, + ): + self.limit = limit + self.distance_field = distance_field + + def __repr__(self): + args = [] + if self.limit is not None: + args.append(f"limit={self.limit}") + if self.distance_field is not None: + args.append(f"distance_field={self.distance_field}") + return f"{self.__class__.__name__}({', '.join(args)})" + + +class SampleOptions: + """Options for the 'sample' pipeline stage.""" + + class Mode(Enum): + DOCUMENTS = "documents" + PERCENT = "percent" + + def __init__(self, value: int | float, mode: Mode | str): + self.value = value + self.mode = SampleOptions.Mode[mode.upper()] if isinstance(mode, str) else mode + + def __repr__(self): + if self.mode == SampleOptions.Mode.DOCUMENTS: + mode_str = "doc_limit" + else: + mode_str = "percentage" + return f"SampleOptions.{mode_str}({self.value})" + + @staticmethod + def doc_limit(value: int): + """ + Sample a set number of documents + + Args: + value: number of documents to sample + """ + return SampleOptions(value, mode=SampleOptions.Mode.DOCUMENTS) + + @staticmethod + def percentage(value: float): + """ + Sample a percentage of documents + + Args: + value: percentage of documents to return + """ + return SampleOptions(value, mode=SampleOptions.Mode.PERCENT) + + +class UnnestOptions: + """Options for configuring the `Unnest` pipeline stage. + + Attributes: + index_field (str): The name of the field to add to each output document, + storing the original 0-based index of the element within the array. + """ + + def __init__(self, index_field: Field | str): + self.index_field = ( + index_field if isinstance(index_field, Field) else Field.of(index_field) + ) + + def __repr__(self): + return f"{self.__class__.__name__}(index_field={self.index_field.path!r})" + + +class Stage(ABC): + """Base class for all pipeline stages. + + Each stage represents a specific operation (e.g., filtering, sorting, + transforming) within a Firestore pipeline. Subclasses define the specific + arguments and behavior for each operation. 
+ """ + + def __init__(self, custom_name: Optional[str] = None): + self.name = custom_name or type(self).__name__.lower() + + def _to_pb(self) -> Pipeline_pb.Stage: + return Pipeline_pb.Stage( + name=self.name, args=self._pb_args(), options=self._pb_options() + ) + + @abstractmethod + def _pb_args(self) -> list[Value]: + """Return Ordered list of arguments the given stage expects""" + raise NotImplementedError + + def _pb_options(self) -> dict[str, Value]: + """Return optional named arguments that certain functions may support.""" + return {} + + def __repr__(self): + items = ("%s=%r" % (k, v) for k, v in self.__dict__.items() if k != "name") + return f"{self.__class__.__name__}({', '.join(items)})" + + +class AddFields(Stage): + """Adds new fields to outputs from previous stages.""" + + def __init__(self, *fields: Selectable): + super().__init__("add_fields") + self.fields = list(fields) + + def _pb_args(self): + return [Selectable._to_value(self.fields)] + + +class Aggregate(Stage): + """Performs aggregation operations, optionally grouped.""" + + def __init__( + self, + *args: AliasedExpression[AggregateFunction], + accumulators: Sequence[AliasedExpression[AggregateFunction]] = (), + groups: Sequence[str | Selectable] = (), + ): + super().__init__() + self.groups: list[Selectable] = [ + Field(f) if isinstance(f, str) else f for f in groups + ] + if args and accumulators: + raise ValueError( + "Aggregate stage contains both positional and keyword accumulators" + ) + self.accumulators = args or accumulators + + def _pb_args(self): + return [ + Selectable._to_value(self.accumulators), + Selectable._to_value(self.groups), + ] + + def __repr__(self): + accumulator_str = ", ".join(repr(v) for v in self.accumulators) + group_str = "" + if self.groups: + if self.accumulators: + group_str = ", " + group_str += f"groups={self.groups}" + return f"{self.__class__.__name__}({accumulator_str}{group_str})" + + +class Collection(Stage): + """Specifies a collection as the initial data source.""" + + def __init__(self, path: str): + super().__init__() + if not path.startswith("/"): + path = f"/{path}" + self.path = path + + def _pb_args(self): + return [Value(reference_value=self.path)] + + +class CollectionGroup(Stage): + """Specifies a collection group as the initial data source.""" + + def __init__(self, collection_id: str): + super().__init__("collection_group") + self.collection_id = collection_id + + def _pb_args(self): + return [Value(reference_value=""), Value(string_value=self.collection_id)] + + +class Database(Stage): + """Specifies the default database as the initial data source.""" + + def __init__(self): + super().__init__() + + def _pb_args(self): + return [] + + +class Distinct(Stage): + """Returns documents with distinct combinations of specified field values.""" + + def __init__(self, *fields: str | Selectable): + super().__init__() + self.fields: list[Selectable] = [ + Field(f) if isinstance(f, str) else f for f in fields + ] + + def _pb_args(self) -> list[Value]: + return [Selectable._to_value(self.fields)] + + +class Documents(Stage): + """Specifies specific documents as the initial data source.""" + + def __init__(self, *paths: str): + super().__init__() + self.paths = paths + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.paths])})" + + @staticmethod + def of(*documents: "BaseDocumentReference") -> "Documents": + doc_paths = ["/" + doc.path for doc in documents] + return Documents(*doc_paths) + + def _pb_args(self): + return 
[Value(reference_value=path) for path in self.paths] + + +class FindNearest(Stage): + """Performs vector distance (similarity) search.""" + + def __init__( + self, + field: str | Expression, + vector: Sequence[float] | Vector, + distance_measure: "DistanceMeasure" | str, + options: Optional["FindNearestOptions"] = None, + ): + super().__init__("find_nearest") + self.field: Expression = Field(field) if isinstance(field, str) else field + self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) + self.distance_measure = ( + distance_measure + if isinstance(distance_measure, DistanceMeasure) + else DistanceMeasure[distance_measure.upper()] + ) + self.options = options or FindNearestOptions() + + def _pb_args(self): + return [ + self.field._to_pb(), + encode_value(self.vector), + Value(string_value=self.distance_measure.name.lower()), + ] + + def _pb_options(self) -> dict[str, Value]: + options = {} + if self.options and self.options.limit is not None: + options["limit"] = Value(integer_value=self.options.limit) + if self.options and self.options.distance_field is not None: + options["distance_field"] = self.options.distance_field._to_pb() + return options + + +class RawStage(Stage): + """Represents a generic, named stage with parameters.""" + + def __init__( + self, + name: str, + *params: Expression | Value, + options: dict[str, Expression | Value] = {}, + ): + super().__init__(name) + self.params: list[Value] = [ + p._to_pb() if isinstance(p, Expression) else p for p in params + ] + self.options: dict[str, Value] = { + k: v._to_pb() if isinstance(v, Expression) else v + for k, v in options.items() + } + + def _pb_args(self): + return self.params + + def _pb_options(self): + return self.options + + def __repr__(self): + return f"{self.__class__.__name__}(name='{self.name}')" + + +class Limit(Stage): + """Limits the maximum number of documents returned.""" + + def __init__(self, limit: int): + super().__init__() + self.limit = limit + + def _pb_args(self): + return [Value(integer_value=self.limit)] + + +class Offset(Stage): + """Skips a specified number of documents.""" + + def __init__(self, offset: int): + super().__init__() + self.offset = offset + + def _pb_args(self): + return [Value(integer_value=self.offset)] + + +class RemoveFields(Stage): + """Removes specified fields from outputs.""" + + def __init__(self, *fields: str | Field): + super().__init__("remove_fields") + self.fields = [Field(f) if isinstance(f, str) else f for f in fields] + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join(repr(f) for f in self.fields)})" + + def _pb_args(self) -> list[Value]: + return [f._to_pb() for f in self.fields] + + +class ReplaceWith(Stage): + """Replaces the document content with the value of a specified field.""" + + def __init__(self, field: Selectable): + super().__init__("replace_with") + self.field = Field(field) if isinstance(field, str) else field + + def _pb_args(self): + return [self.field._to_pb(), Value(string_value="full_replace")] + + +class Sample(Stage): + """Performs pseudo-random sampling of documents.""" + + def __init__(self, limit_or_options: int | SampleOptions): + super().__init__() + if isinstance(limit_or_options, int): + options = SampleOptions.doc_limit(limit_or_options) + else: + options = limit_or_options + self.options: SampleOptions = options + + def _pb_args(self): + if self.options.mode == SampleOptions.Mode.DOCUMENTS: + return [ + Value(integer_value=self.options.value), + Value(string_value="documents"), + ] + else: + 
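+            # Percent mode: the sampling value is encoded as a double, paired
+            # with a "percent" marker.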
return [ + Value(double_value=self.options.value), + Value(string_value="percent"), + ] + + +class Select(Stage): + """Selects or creates a set of fields.""" + + def __init__(self, *selections: str | Selectable): + super().__init__() + self.projections = [Field(s) if isinstance(s, str) else s for s in selections] + + def _pb_args(self) -> list[Value]: + return [Selectable._value_from_selectables(*self.projections)] + + +class Sort(Stage): + """Sorts documents based on specified criteria.""" + + def __init__(self, *orders: "Ordering"): + super().__init__() + self.orders = list(orders) + + def _pb_args(self): + return [o._to_pb() for o in self.orders] + + +class Union(Stage): + """Performs a union of documents from two pipelines.""" + + def __init__(self, other: _BasePipeline): + super().__init__() + self.other = other + + def _pb_args(self): + return [Value(pipeline_value=self.other._to_pb().pipeline)] + + +class Unnest(Stage): + """Produces a document for each element in an array field.""" + + def __init__( + self, + field: Selectable | str, + alias: Field | str | None = None, + options: UnnestOptions | None = None, + ): + super().__init__() + self.field: Selectable = Field(field) if isinstance(field, str) else field + if alias is None: + self.alias = self.field + elif isinstance(alias, str): + self.alias = Field(alias) + else: + self.alias = alias + self.options = options + + def _pb_args(self): + return [self.field._to_pb(), self.alias._to_pb()] + + def _pb_options(self): + options = {} + if self.options is not None: + options["index_field"] = self.options.index_field._to_pb() + return options + + +class Where(Stage): + """Filters documents based on a specified condition.""" + + def __init__(self, condition: BooleanExpression): + super().__init__() + self.condition = condition + + def _pb_args(self): + return [self.condition._to_pb()] diff --git a/google/cloud/firestore_v1/query_profile.py b/google/cloud/firestore_v1/query_profile.py index 6925f83ff..5e8491fc6 100644 --- a/google/cloud/firestore_v1/query_profile.py +++ b/google/cloud/firestore_v1/query_profile.py @@ -19,6 +19,12 @@ from dataclasses import dataclass from google.protobuf.json_format import MessageToDict +from google.cloud.firestore_v1.types.document import MapValue +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, +) +from google.protobuf.wrappers_pb2 import StringValue @dataclass(frozen=True) @@ -42,6 +48,32 @@ def _to_dict(self): return {"analyze": self.analyze} +@dataclass(frozen=True) +class PipelineExplainOptions: + """ + Explain options for pipeline queries. + + Set on a pipeline.execution() or pipeline.stream() call, to provide + explain_stats in the pipeline output + + :type mode: str + :param mode: Optional. The mode of operation for this explain query. + When set to 'analyze', the query will be executed and return the full + query results along with execution statistics. + + :type output_format: str | None + :param output_format: Optional. The format in which to return the explain + stats. + """ + + mode: str = "analyze" + + def _to_value(self): + out_dict = {"mode": Value(string_value=self.mode)} + value_pb = MapValue(fields=out_dict) + return Value(map_value=value_pb) + + @dataclass(frozen=True) class PlanSummary: """ @@ -143,3 +175,54 @@ class QueryExplainError(Exception): """ pass + + +class ExplainStats: + """ + Contains query profiling statistics for a pipeline query. 
+ + This class is not meant to be instantiated directly by the user. Instead, an + instance of `ExplainStats` may be returned by pipeline execution methods + when `explain_options` are provided. + + It provides methods to access the explain statistics in different formats. + """ + + def __init__(self, stats_pb: ExplainStats_pb): + """ + Args: + stats_pb (ExplainStats_pb): The raw protobuf message for explain stats. + """ + self._stats_pb = stats_pb + + def get_text(self) -> str: + """ + Returns the explain stats as a string. + + This method is suitable for explain formats that have a text-based output, + such as 'text' or 'json'. + + Returns: + str: The string representation of the explain stats. + + Raises: + QueryExplainError: If the explain stats payload from the backend is not + a string. This can happen if a non-text output format was requested. + """ + pb_data = self._stats_pb._pb.data + content = StringValue() + if pb_data.Unpack(content): + return content.value + raise QueryExplainError( + "Unable to decode explain stats. Did you request an output format that returns a string value, such as 'text' or 'json'?" + ) + + def get_raw(self) -> ExplainStats_pb: + """ + Returns the explain stats in an encoded proto format, as returned from the Firestore backend. + The caller is responsible for unpacking this proto message. + + Returns: + google.cloud.firestore_v1.types.explain_stats.ExplainStats: the proto from the backend + """ + return self._stats_pb diff --git a/noxfile.py b/noxfile.py index 3e8b80770..a1bce0822 100644 --- a/noxfile.py +++ b/noxfile.py @@ -75,6 +75,7 @@ SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", "six", + "pyyaml", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] SYSTEM_TEST_DEPENDENCIES: List[str] = [] diff --git a/tests/system/pipeline_e2e/aggregates.yaml b/tests/system/pipeline_e2e/aggregates.yaml new file mode 100644 index 000000000..64a42698b --- /dev/null +++ b/tests/system/pipeline_e2e/aggregates.yaml @@ -0,0 +1,284 @@ +tests: + - description: "testAggregates - count" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpression: + - FunctionExpression.count: + - Field: rating + - "count" + assert_results: + - count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + args: + - fieldReferenceValue: rating + - mapValue: {} + name: aggregate + - description: "testAggregates - count_if" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpression: + - FunctionExpression.count_if: + - FunctionExpression.greater_than: + - Field: rating + - Constant: 4.2 + - "count_if_rating_gt_4_2" + assert_results: + - count_if_rating_gt_4_2: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count_if_rating_gt_4_2: + functionValue: + name: count_if + args: + - functionValue: + name: greater_than + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + - mapValue: {} + name: aggregate + - description: "testAggregates - count_distinct" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpression: + - FunctionExpression.count_distinct: + - Field: genre + - "distinct_genres" + assert_results: + - distinct_genres: 8 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + distinct_genres: + functionValue: + name: count_distinct + args: + - 
fieldReferenceValue: genre + - mapValue: {} + name: aggregate + - description: "testAggregates - avg, count, max" + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpression: + - FunctionExpression.count: + - Field: rating + - "count" + - AliasedExpression: + - FunctionExpression.average: + - Field: rating + - "avg_rating" + - AliasedExpression: + - FunctionExpression.maximum: + - Field: rating + - "max_rating" + assert_results: + - count: 2 + avg_rating: 4.4 + max_rating: 4.6 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: equal + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: average + count: + functionValue: + name: count + args: + - fieldReferenceValue: rating + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: maximum + - mapValue: {} + name: aggregate + - description: testGroupBysWithoutAccumulators + pipeline: + - Collection: books + - Where: + - FunctionExpression.less_than: + - Field: published + - Constant: 1900 + - Aggregate: + accumulators: [] + groups: [genre] + assert_error: ".* requires at least one accumulator" + - description: testGroupBysAndAggregate + pipeline: + - Collection: books + - Where: + - FunctionExpression.less_than: + - Field: published + - Constant: 1984 + - Aggregate: + accumulators: + - AliasedExpression: + - FunctionExpression.average: + - Field: rating + - "avg_rating" + groups: [genre] + - Where: + - FunctionExpression.greater_than: + - Field: avg_rating + - Constant: 4.3 + - Sort: + - Ordering: + - Field: avg_rating + - ASCENDING + assert_results: + - avg_rating: 4.4 + genre: Science Fiction + - avg_rating: 4.5 + genre: Romance + - avg_rating: 4.7 + genre: Fantasy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1984' + name: less_than + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: average + - mapValue: + fields: + genre: + fieldReferenceValue: genre + name: aggregate + - args: + - functionValue: + args: + - fieldReferenceValue: avg_rating + - doubleValue: 4.3 + name: greater_than + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: avg_rating + name: sort + - description: testMinMax + pipeline: + - Collection: books + - Aggregate: + - AliasedExpression: + - FunctionExpression.count: + - Field: rating + - "count" + - AliasedExpression: + - FunctionExpression.maximum: + - Field: rating + - "max_rating" + - AliasedExpression: + - FunctionExpression.minimum: + - Field: published + - "min_published" + assert_results: + - count: 10 + max_rating: 4.7 + min_published: 1813 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + args: + - fieldReferenceValue: rating + name: count + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: maximum + min_published: + functionValue: + args: + - fieldReferenceValue: published + name: minimum + - mapValue: {} + name: aggregate + - description: testSum + pipeline: + - 
Collection: books + - Where: + - FunctionExpression.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpression: + - FunctionExpression.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + diff --git a/tests/system/pipeline_e2e/array.yaml b/tests/system/pipeline_e2e/array.yaml new file mode 100644 index 000000000..f82f1cbc1 --- /dev/null +++ b/tests/system/pipeline_e2e/array.yaml @@ -0,0 +1,464 @@ +tests: + - description: testArrayContains + pipeline: + - Collection: books + - Where: + - FunctionExpression.array_contains: + - Field: tags + - Constant: comedy + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + awards: + hugo: true + nebula: false + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: ["comedy", "space", "adventure"] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - stringValue: comedy + name: array_contains + name: where + - description: testArrayContainsAny + pipeline: + - Collection: books + - Where: + - FunctionExpression.array_contains_any: + - Field: tags + - - Constant: comedy + - Constant: classic + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Pride and Prejudice + - title: The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: comedy + - stringValue: classic + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayContainsAll + pipeline: + - Collection: books + - Where: + - FunctionExpression.array_contains_all: + - Field: tags + - - Constant: adventure + - Constant: magic + - Select: + - title + assert_results: + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: adventure + - stringValue: magic + name: array + name: array_contains_all + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - description: testArrayLength + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - FunctionExpression.array_length: + - Field: tags + - "tagsCount" + - Where: + - FunctionExpression.equal: + - Field: tagsCount + - Constant: 3 + assert_results: # All documents have 3 tags + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + tagsCount: + functionValue: + args: + - fieldReferenceValue: tags + name: array_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: tagsCount + - integerValue: '3' + name: equal + name: where + - description: testArrayReverse + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The 
Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - FunctionExpression.array_reverse: + - Field: tags + - "reversedTags" + assert_results: + - reversedTags: + - adventure + - space + - comedy + - description: testArrayConcat + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - FunctionExpression.array_concat: + - Field: tags + - ["new_tag", "another_tag"] + - "concatenatedTags" + assert_results: + - concatenatedTags: + - comedy + - space + - adventure + - new_tag + - another_tag + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "new_tag" + - stringValue: "another_tag" + name: array + name: array_concat + name: select + - description: testArrayConcatMultiple + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpression: + - FunctionExpression.array_concat: + - Field: tags + - ["sci-fi"] + - ["classic", "epic"] + - "concatenatedTags" + assert_results: + - concatenatedTags: + - politics + - desert + - ecology + - sci-fi + - classic + - epic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "sci-fi" + name: array + - functionValue: + args: + - stringValue: "classic" + - stringValue: "epic" + name: array + name: array_concat + name: select + - description: testArrayContainsAnyWithField + pipeline: + - Collection: books + - AddFields: + - AliasedExpression: + - FunctionExpression.array_concat: + - Field: tags + - Array: ["Dystopian"] + - "new_tags" + - Where: + - FunctionExpression.array_contains_any: + - Field: new_tags + - - Constant: non_existent_tag + - Field: genre + - Select: + - title + - genre + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + genre: "Dystopian" + - title: "The Handmaid's Tale" + genre: "Dystopian" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + new_tags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "Dystopian" + name: array + name: array_concat + name: add_fields + - args: + - functionValue: + args: + - fieldReferenceValue: new_tags + - functionValue: + args: + - stringValue: "non_existent_tag" + - fieldReferenceValue: genre + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + genre: + fieldReferenceValue: genre + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayConcatLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpression: + - FunctionExpression.array_concat: + - Array: [1, 2, 3] + - 
Array: [4, 5] + - "concatenated" + assert_results: + - concatenated: [1, 2, 3, 4, 5] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + concatenated: + functionValue: + args: + - functionValue: + args: + - integerValue: '1' + - integerValue: '2' + - integerValue: '3' + name: array + - functionValue: + args: + - integerValue: '4' + - integerValue: '5' + name: array + name: array_concat + name: select + - description: testArrayGet + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - FunctionExpression.array_get: + - Field: tags + - Constant: 0 + - "firstTag" + assert_results: + - firstTag: "comedy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + firstTag: + functionValue: + args: + - fieldReferenceValue: tags + - integerValue: '0' + name: array_get + name: select + - description: testArrayGet_NegativeOffset + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - FunctionExpression.array_get: + - Field: tags + - Constant: -1 + - "lastTag" + assert_results: + - lastTag: "adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + lastTag: + functionValue: + args: + - fieldReferenceValue: tags + - integerValue: '-1' + name: array_get + name: select \ No newline at end of file diff --git a/tests/system/pipeline_e2e/data.yaml b/tests/system/pipeline_e2e/data.yaml new file mode 100644 index 000000000..f2533d2b1 --- /dev/null +++ b/tests/system/pipeline_e2e/data.yaml @@ -0,0 +1,147 @@ +data: + books: + book1: + title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + genre: "Science Fiction" + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + book2: + title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + book3: + title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + genre: "Magical Realism" + published: 1967 + rating: 4.3 + tags: + - family + - history + - fantasy + awards: + nobel: true + nebula: false + book4: + title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + genre: "Fantasy" + published: 1954 + rating: 4.7 + tags: + - adventure + - magic + - epic + awards: + hugo: false + nebula: false + book5: + title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + arthur c. 
clarke: true + booker prize: false + book6: + title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + genre: "Psychological Thriller" + published: 1866 + rating: 4.3 + tags: + - philosophy + - crime + - redemption + awards: + none: true + book7: + title: "To Kill a Mockingbird" + author: "Harper Lee" + genre: "Southern Gothic" + published: 1960 + rating: 4.2 + tags: + - racism + - injustice + - coming-of-age + awards: + pulitzer: true + book8: + title: "1984" + author: "George Orwell" + genre: "Dystopian" + published: 1949 + rating: 4.2 + tags: + - surveillance + - totalitarianism + - propaganda + awards: + prometheus: true + book9: + title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + genre: "Modernist" + published: 1925 + rating: 4.0 + tags: + - wealth + - american dream + - love + awards: + none: true + book10: + title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + timestamps: + ts1: + time: "1993-04-28T12:01:00.654321+00:00" + micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + vectors: + vec1: + embedding: [1.0, 2.0, 3.0] + vec2: + embedding: [4.0, 5.0, 6.0, 7.0] + vec3: + embedding: [5.0, 6.0, 7.0] + vec4: + embedding: [1.0, 2.0, 4.0] + errors: + doc_with_nan: + value: "NaN" + doc_with_null: + value: null \ No newline at end of file diff --git a/tests/system/pipeline_e2e/date_and_time.yaml b/tests/system/pipeline_e2e/date_and_time.yaml new file mode 100644 index 000000000..2319b333b --- /dev/null +++ b/tests/system/pipeline_e2e/date_and_time.yaml @@ -0,0 +1,103 @@ +tests: + - description: testCurrentTimestamp + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpression: + - And: + - FunctionExpression.greater_than_or_equal: + - CurrentTimestamp: [] + - FunctionExpression.unix_seconds_to_timestamp: + - Constant: 1735689600 # 2025-01-01 + - FunctionExpression.less_than: + - CurrentTimestamp: [] + - FunctionExpression.unix_seconds_to_timestamp: + - Constant: 4892438400 # 2125-01-01 + - "is_between_2025_and_2125" + assert_results: + - is_between_2025_and_2125: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + is_between_2025_and_2125: + functionValue: + name: and + args: + - functionValue: + name: greater_than_or_equal + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '1735689600' + - functionValue: + name: less_than + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '4892438400' + name: select + - description: testTimestampFunctionExpressions + pipeline: + - Collection: timestamps + - Select: + - AliasedExpression: + - FunctionExpression.timestamp_to_unix_micros: + - Field: time + - "micros" + - AliasedExpression: + - FunctionExpression.timestamp_to_unix_millis: + - Field: time + - "millis" + - AliasedExpression: + - FunctionExpression.timestamp_to_unix_seconds: + - Field: time + - "seconds" + - AliasedExpression: + - FunctionExpression.unix_micros_to_timestamp: + - Field: micros + - "from_micros" + - AliasedExpression: + - FunctionExpression.unix_millis_to_timestamp: + - Field: millis + - "from_millis" + - AliasedExpression: + - FunctionExpression.unix_seconds_to_timestamp: + - Field: seconds + - "from_seconds" + - 
AliasedExpression: + - FunctionExpression.timestamp_add: + - Field: time + - Constant: "day" + - Constant: 1 + - "plus_day" + - AliasedExpression: + - FunctionExpression.timestamp_subtract: + - Field: time + - Constant: "hour" + - Constant: 1 + - "minus_hour" + assert_results: + - micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + from_micros: "1993-04-28T12:01:00.654321+00:00" + from_millis: "1993-04-28T12:01:00.654000+00:00" + from_seconds: "1993-04-28T12:01:00.000000+00:00" + plus_day: "1993-04-29T12:01:00.654321+00:00" + minus_hour: "1993-04-28T11:01:00.654321+00:00" diff --git a/tests/system/pipeline_e2e/general.yaml b/tests/system/pipeline_e2e/general.yaml new file mode 100644 index 000000000..46a10cd4d --- /dev/null +++ b/tests/system/pipeline_e2e/general.yaml @@ -0,0 +1,687 @@ +tests: + - description: selectSpecificFields + pipeline: + - Collection: books + - Select: + - title + - author + - Sort: + - Ordering: + - Field: author + - ASCENDING + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + - title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + - title: "Dune" + author: "Frank Herbert" + - title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + - title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + - title: "1984" + author: "George Orwell" + - title: "To Kill a Mockingbird" + author: "Harper Lee" + - title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + - title: "Pride and Prejudice" + author: "Jane Austen" + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - description: addAndRemoveFields + pipeline: + - Collection: books + - AddFields: + - AliasedExpression: + - FunctionExpression.string_concat: + - Field: author + - Constant: _ + - Field: title + - "author_title" + - AliasedExpression: + - FunctionExpression.string_concat: + - Field: title + - Constant: _ + - Field: author + - "title_author" + - RemoveFields: + - title_author + - tags + - awards + - rating + - title + - Field: published + - Field: genre + - Field: nestedField # Field does not exist, should be ignored + - Sort: + - Ordering: + - Field: author_title + - ASCENDING + assert_results: + - author: Douglas Adams + author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy + - author: F. Scott Fitzgerald + author_title: F. Scott Fitzgerald_The Great Gatsby + - author: Frank Herbert + author_title: Frank Herbert_Dune + - author: Fyodor Dostoevsky + author_title: Fyodor Dostoevsky_Crime and Punishment + - author: Gabriel García Márquez + author_title: Gabriel García Márquez_One Hundred Years of Solitude + - author: George Orwell + author_title: George Orwell_1984 + - author: Harper Lee + author_title: Harper Lee_To Kill a Mockingbird + - author: J.R.R. Tolkien + author_title: J.R.R. 
Tolkien_The Lord of the Rings + - author: Jane Austen + author_title: Jane Austen_Pride and Prejudice + - author: Margaret Atwood + author_title: Margaret Atwood_The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author_title: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: _ + - fieldReferenceValue: title + name: string_concat + title_author: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: _ + - fieldReferenceValue: author + name: string_concat + name: add_fields + - args: + - fieldReferenceValue: title_author + - fieldReferenceValue: tags + - fieldReferenceValue: awards + - fieldReferenceValue: rating + - fieldReferenceValue: title + - fieldReferenceValue: published + - fieldReferenceValue: genre + - fieldReferenceValue: nestedField + name: remove_fields + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author_title + name: sort + - description: testPipelineWithOffsetAndLimit + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Offset: 5 + - Limit: 3 + - Select: + - title + - author + assert_results: + - title: "1984" + author: George Orwell + - title: To Kill a Mockingbird + author: Harper Lee + - title: The Lord of the Rings + author: J.R.R. Tolkien + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - integerValue: '5' + name: offset + - args: + - integerValue: '3' + name: limit + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - SampleOptions: + - 0.6 + - percent + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - doubleValue: 0.6 + - stringValue: percent + name: sample + - description: testUnion + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: genre + - Constant: Romance + - Union: + - Pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: genre + - Constant: Dystopian + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: Pride and Prejudice + - title: "The Handmaid's Tale" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: equal + name: where + - args: + - pipelineValue: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: equal + name: where + name: union + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title 
+ name: sort + - description: testUnionFullCollection + pipeline: + - Collection: books + - Union: + - Pipeline: + - Collection: books + assert_count: 20 # Results will be duplicated + - description: testDocumentId + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - FunctionExpression.document_id: + - Field: __name__ + - "doc_id" + assert_results: + - doc_id: "book1" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + doc_id: + functionValue: + name: document_id + args: + - fieldReferenceValue: __name__ + name: select + - description: testCollectionId + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpression: + - FunctionExpression.collection_id: + - Field: __name__ + - "collectionName" + assert_results: + - collectionName: "books" + - description: testCollectionGroup + pipeline: + - CollectionGroup: books + - Select: + - title + - Distinct: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: "Crime and Punishment" + - title: "Dune" + - title: "One Hundred Years of Solitude" + - title: "Pride and Prejudice" + - title: "The Great Gatsby" + - title: "The Handmaid's Tale" + - title: "The Hitchhiker's Guide to the Galaxy" + - title: "The Lord of the Rings" + - title: "To Kill a Mockingbird" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: '' + - stringValue: books + name: collection_group + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: distinct + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + + - description: testDistinct + pipeline: + - Collection: books + - Distinct: + - genre + - Sort: + - Ordering: + - Field: genre + - ASCENDING + assert_results: + - genre: Dystopian + - genre: Fantasy + - genre: Magical Realism + - genre: Modernist + - genre: Psychological Thriller + - genre: Romance + - genre: Science Fiction + - genre: Southern Gothic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + genre: + fieldReferenceValue: genre + name: distinct + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: genre + name: sort + + - description: testDocuments + pipeline: + - Documents: + - /books/book1 + - /books/book10 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Dune" + - title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books/book1 + - referenceValue: /books/book10 + name: documents + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + + - description: testDatabase + pipeline: + - Database + - Select: + - title + - Distinct: + - title + - Aggregate: + - AliasedExpression: + - Count: [] + - count + - Select: + - AliasedExpression: + - 
Conditional: + - FunctionExpression.greater_than_or_equal: + - Field: count + - Constant: 10 + - Constant: True + - Constant: False + - result + assert_results: + - result: True + assert_proto: + pipeline: + stages: + - name: database + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: distinct + - args: + - mapValue: + fields: + count: + functionValue: + name: count + - mapValue: {} + name: aggregate + - name: select + args: + - mapValue: + fields: + result: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: count + - integerValue: '10' + name: greater_than_or_equal + - booleanValue: true + - booleanValue: false + name: conditional + - description: testRawStage + pipeline: + - RawStage: + - "collection" + - Value: + reference_value: "/books" + - RawStage: + - "where" + - FunctionExpression.equal: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - RawStage: + - "select" + - Value: + map_value: + fields: + author: + field_reference_value: author + assert_results: + - author: Douglas Adams + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: equal + name: where + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + name: select + - description: testUnnest + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + - tags + - tags_alias + - Select: tags_alias + assert_results: + - tags_alias: comedy + - tags_alias: space + - tags_alias: adventure + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: equal + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias + name: unnest + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + name: select + - description: testUnnestWithOptions + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + field: tags + alias: tags_alias + options: + UnnestOptions: + - index + - Select: + - tags_alias + - index + assert_results: + - tags_alias: comedy + index: 0 + - tags_alias: space + index: 1 + - tags_alias: adventure + index: 2 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: equal + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias + name: unnest + options: + index_field: + fieldReferenceValue: index + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + index: + fieldReferenceValue: index + name: select + - description: replaceWith + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - ReplaceWith: + - Field: awards + assert_results: + - hugo: True + nebula: False + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - 
args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - fieldReferenceValue: awards + - stringValue: full_replace + name: replace_with \ No newline at end of file diff --git a/tests/system/pipeline_e2e/logical.yaml b/tests/system/pipeline_e2e/logical.yaml new file mode 100644 index 000000000..296cfda14 --- /dev/null +++ b/tests/system/pipeline_e2e/logical.yaml @@ -0,0 +1,690 @@ +tests: + - description: whereByMultipleConditions + pipeline: + - Collection: books + - Where: + - And: + - FunctionExpression.greater_than: + - Field: rating + - Constant: 4.5 + - FunctionExpression.equal: + - Field: genre + - Constant: Science Fiction + assert_results: + - title: Dune + author: Frank Herbert + genre: Science Fiction + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: greater_than + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: equal + name: and + name: where + - description: whereByOrCondition + pipeline: + - Collection: books + - Where: + - Or: + - FunctionExpression.equal: + - Field: genre + - Constant: Romance + - FunctionExpression.equal: + - Field: genre + - Constant: Dystopian + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: Pride and Prejudice + - title: The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: equal + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: equal + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testComparisonOperators + pipeline: + - Collection: books + - Where: + - And: + - FunctionExpression.greater_than: + - Field: rating + - Constant: 4.2 + - FunctionExpression.less_than_or_equal: + - Field: rating + - Constant: 4.5 + - FunctionExpression.not_equal: + - Field: genre + - Constant: Science Fiction + - Select: + - rating + - title + - Sort: + - Ordering: + - title + - ASCENDING + assert_results: + - rating: 4.3 + title: Crime and Punishment + - rating: 4.3 + title: One Hundred Years of Solitude + - rating: 4.5 + title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + name: greater_than + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: less_than_or_equal + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: not_equal + name: and + name: where + - args: + - mapValue: + fields: + rating: + fieldReferenceValue: rating + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + 
fieldReferenceValue: title + name: sort + - description: testLogicalOperators + pipeline: + - Collection: books + - Where: + - Or: + - And: + - FunctionExpression.greater_than: + - Field: rating + - Constant: 4.5 + - FunctionExpression.equal: + - Field: genre + - Constant: Science Fiction + - FunctionExpression.less_than: + - Field: published + - Constant: 1900 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Crime and Punishment + - title: Dune + - title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: greater_than + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: equal + name: and + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: less_than + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testIsNull + pipeline: + - Collection: errors + - Where: + - FunctionExpression.equal: + - Field: value + - null + assert_results: + - value: null + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /errors + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: value + - nullValue: null + name: equal + name: where + - description: testIsNan + pipeline: + - Collection: errors + - Where: + - FunctionExpression.equal: + - Field: value + - NaN + assert_count: 1 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /errors + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: value + - doubleValue: NaN + name: equal + name: where + - description: testIsAbsent + pipeline: + - Collection: books + - Where: + - FunctionExpression.is_absent: + - Field: awards.pulitzer + assert_count: 9 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.pulitzer + name: is_absent + name: where + - description: testIfAbsent + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - FunctionExpression.if_absent: + - Field: awards.pulitzer + - Constant: false + - "pulitzer_award" + - title + - Where: + - FunctionExpression.equal: + - Field: pulitzer_award + - Constant: true + assert_results: + - pulitzer_award: true + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + pulitzer_award: + functionValue: + name: if_absent + args: + - fieldReferenceValue: awards.pulitzer + - booleanValue: false + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: pulitzer_award + - booleanValue: true + name: equal + name: where + - description: testIsError + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - FunctionExpression.is_error: + - FunctionExpression.divide: + - Field: rating + - Constant: "string" + - "is_error_result" + - Limit: 1 + assert_results: + - is_error_result: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - 
mapValue: + fields: + is_error_result: + functionValue: + name: is_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - stringValue: "string" + name: select + - args: + - integerValue: '1' + name: limit + - description: testIfError + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - FunctionExpression.if_error: + - FunctionExpression.divide: + - Field: rating + - Field: genre + - Constant: "An error occurred" + - "if_error_result" + - Limit: 1 + assert_results: + - if_error_result: "An error occurred" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + if_error_result: + functionValue: + name: if_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - fieldReferenceValue: genre + - stringValue: "An error occurred" + name: select + - args: + - integerValue: '1' + name: limit + - description: testLogicalMinMax + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: Douglas Adams + - Select: + - AliasedExpression: + - FunctionExpression.logical_maximum: + - Field: rating + - Constant: 4.5 + - "max_rating" + - AliasedExpression: + - FunctionExpression.logical_minimum: + - Field: published + - Constant: 1900 + - "min_published" + assert_results: + - max_rating: 4.5 + min_published: 1900 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + min_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: minimum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: maximum + name: select + - description: testLogicalMinMaxWithMultipleInputs + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: Douglas Adams + - Select: + - AliasedExpression: + - FunctionExpression.logical_maximum: + - Field: rating + - Constant: 4.5 + - Constant: 3.0 + - Constant: 5.0 + - "max_rating" + - AliasedExpression: + - FunctionExpression.logical_minimum: + - Field: published + - Constant: 1900 + - Constant: 2000 + - Constant: 1984 + - "min_published" + assert_results: + - max_rating: 5.0 + min_published: 1900 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + min_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + - integerValue: '2000' + - integerValue: '1984' + name: minimum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + - doubleValue: 3.0 + - doubleValue: 5.0 + name: maximum + name: select + - description: testGreaterThanOrEqual + pipeline: + - Collection: books + - Where: + - FunctionExpression.greater_than_or_equal: + - Field: rating + - Constant: 4.6 + - Select: + - title + - rating + - Sort: + - Ordering: + - Field: rating + - ASCENDING + assert_results: + - title: Dune + rating: 4.6 + - title: The Lord of the Rings + rating: 4.7 + - description: testInAndNotIn + pipeline: + - Collection: books + - Where: + - And: + - FunctionExpression.equal_any: + - Field: 
genre + - - Constant: Romance + - Constant: Dystopian + - FunctionExpression.not_equal_any: + - Field: author + - - Constant: "George Orwell" + assert_results: + - title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + "arthur c. clarke": true + "booker prize": false + - description: testExists + pipeline: + - Collection: books + - Where: + - And: + - FunctionExpression.exists: + - Field: awards.pulitzer + - FunctionExpression.equal: + - Field: awards.pulitzer + - Constant: true + - Select: + - title + assert_results: + - title: To Kill a Mockingbird + - description: testXor + pipeline: + - Collection: books + - Where: + - Xor: + - - FunctionExpression.equal: + - Field: genre + - Constant: Romance + - FunctionExpression.greater_than: + - Field: published + - Constant: 1980 + - Select: + - title + - genre + - published + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + genre: "Romance" + published: 1813 + - title: "The Handmaid's Tale" + genre: "Dystopian" + published: 1985 + - description: testConditional + pipeline: + - Collection: books + - Select: + - title + - AliasedExpression: + - Conditional: + - FunctionExpression.greater_than: + - Field: published + - Constant: 1950 + - Constant: "Modern" + - Constant: "Classic" + - "era" + - Sort: + - Ordering: + - Field: title + - ASCENDING + - Limit: 4 + assert_results: + - title: "1984" + era: "Classic" + - title: "Crime and Punishment" + era: "Classic" + - title: "Dune" + era: "Modern" + - title: "One Hundred Years of Solitude" + era: "Modern" + - description: testFieldToFieldComparison + pipeline: + - Collection: books + - Where: + - FunctionExpression.greater_than: + - Field: published + - Field: rating + - Select: + - title + assert_count: 10 # All books were published after year 4.7 + - description: testExistsNegative + pipeline: + - Collection: books + - Where: + - FunctionExpression.exists: + - Field: non_existent_field + assert_count: 0 + - description: testConditionalWithFields + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal_any: + - Field: title + - - Constant: "Dune" + - Constant: "1984" + - Select: + - title + - AliasedExpression: + - Conditional: + - FunctionExpression.greater_than: + - Field: published + - Constant: 1950 + - Field: author + - Field: genre + - "conditional_field" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + conditional_field: "Dystopian" + - title: "Dune" + conditional_field: "Frank Herbert" diff --git a/tests/system/pipeline_e2e/map.yaml b/tests/system/pipeline_e2e/map.yaml new file mode 100644 index 000000000..3e5e5de12 --- /dev/null +++ b/tests/system/pipeline_e2e/map.yaml @@ -0,0 +1,269 @@ +tests: + - description: testMapGet + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: published + - DESCENDING + - Select: + - AliasedExpression: + - FunctionExpression.map_get: + - Field: awards + - hugo + - "hugoAward" + - Field: title + - Where: + - FunctionExpression.equal: + - Field: hugoAward + - Constant: true + assert_results: + - hugoAward: true + title: The Hitchhiker's Guide to the Galaxy + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - 
args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: published + name: sort + - args: + - mapValue: + fields: + hugoAward: + functionValue: + args: + - fieldReferenceValue: awards + - stringValue: hugo + name: map_get + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: hugoAward + - booleanValue: true + name: equal + name: where + - description: testMapGetWithField + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "Dune" + - AddFields: + - AliasedExpression: + - Constant: "hugo" + - "award_name" + - Select: + - AliasedExpression: + - FunctionExpression.map_get: + - Field: awards + - Field: award_name + - "hugoAward" + - Field: title + assert_results: + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + award_name: + stringValue: "hugo" + name: add_fields + - args: + - mapValue: + fields: + hugoAward: + functionValue: + name: map_get + args: + - fieldReferenceValue: awards + - fieldReferenceValue: award_name + title: + fieldReferenceValue: title + name: select + - description: testMapRemove + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpression: + - FunctionExpression.map_remove: + - Field: awards + - "nebula" + - "awards_removed" + assert_results: + - awards_removed: + hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_removed: + functionValue: + name: map_remove + args: + - fieldReferenceValue: awards + - stringValue: "nebula" + name: select + - description: testMapMerge + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpression: + - FunctionExpression.map_merge: + - Field: awards + - Map: + elements: {"new_award": true, "hugo": false} + - Map: + elements: {"another_award": "yes"} + - "awards_merged" + assert_results: + - awards_merged: + hugo: false + nebula: true + new_award: true + another_award: "yes" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_merged: + functionValue: + name: map_merge + args: + - fieldReferenceValue: awards + - functionValue: + name: map + args: + - stringValue: "new_award" + - booleanValue: true + - stringValue: "hugo" + - booleanValue: false + - functionValue: + name: map + args: + - stringValue: "another_award" + - stringValue: "yes" + name: select + - description: testNestedFields + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: awards.hugo + - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true 
+ assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: equal + name: where + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + - description: testMapMergeLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpression: + - FunctionExpression.map_merge: + - Map: + elements: {"a": "orig", "b": "orig"} + - Map: + elements: {"b": "new", "c": "new"} + - "merged" + assert_results: + - merged: + a: "orig" + b: "new" + c: "new" diff --git a/tests/system/pipeline_e2e/math.yaml b/tests/system/pipeline_e2e/math.yaml new file mode 100644 index 000000000..4d35f746d --- /dev/null +++ b/tests/system/pipeline_e2e/math.yaml @@ -0,0 +1,309 @@ +tests: + - description: testFieldToFieldArithmetic + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpression: + - FunctionExpression.add: + - Field: published + - Field: rating + - "pub_plus_rating" + assert_results: + - pub_plus_rating: 1969.6 + - description: testMathFunctionExpressionessions + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpression: + - FunctionExpression.abs: + - Field: rating + - "abs_rating" + - AliasedExpression: + - FunctionExpression.ceil: + - Field: rating + - "ceil_rating" + - AliasedExpression: + - FunctionExpression.exp: + - Field: rating + - "exp_rating" + - AliasedExpression: + - FunctionExpression.floor: + - Field: rating + - "floor_rating" + - AliasedExpression: + - FunctionExpression.ln: + - Field: rating + - "ln_rating" + - AliasedExpression: + - FunctionExpression.log10: + - Field: rating + - "log_rating_base10" + - AliasedExpression: + - FunctionExpression.log: + - Field: rating + - Constant: 2 + - "log_rating_base2" + - AliasedExpression: + - FunctionExpression.pow: + - Field: rating + - Constant: 2 + - "pow_rating" + - AliasedExpression: + - FunctionExpression.sqrt: + - Field: rating + - "sqrt_rating" + assert_results_approximate: + - abs_rating: 4.2 + ceil_rating: 5.0 + exp_rating: 66.686331 + floor_rating: 4.0 + ln_rating: 1.4350845 + log_rating_base10: 0.623249 + log_rating_base2: 2.0704 + pow_rating: 17.64 + sqrt_rating: 2.049390 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + abs_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: abs + ceil_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ceil + exp_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: exp + floor_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: floor + ln_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ln + log_rating_base10: + functionValue: + args: + - fieldReferenceValue: rating + name: log10 + log_rating_base2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: log + pow_rating: + functionValue: + args: + - 
fieldReferenceValue: rating + - integerValue: '2' + name: pow + sqrt_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: sqrt + name: select + - description: testRoundFunctionExpressionessions + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal_any: + - Field: title + - - Constant: "To Kill a Mockingbird" # rating 4.2 + - Constant: "Pride and Prejudice" # rating 4.5 + - Constant: "The Lord of the Rings" # rating 4.7 + - Select: + - title + - AliasedExpression: + - FunctionExpression.round: + - Field: rating + - "round_rating" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + round_rating: 5.0 + - title: "The Lord of the Rings" + round_rating: 5.0 + - title: "To Kill a Mockingbird" + round_rating: 4.0 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - functionValue: + args: + - stringValue: "To Kill a Mockingbird" + - stringValue: "Pride and Prejudice" + - stringValue: "The Lord of the Rings" + name: array + name: equal_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + round_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: round + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArithmeticOperations + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpression: + - FunctionExpression.add: + - Field: rating + - Constant: 1 + - "ratingPlusOne" + - AliasedExpression: + - FunctionExpression.subtract: + - Field: published + - Constant: 1900 + - "yearsSince1900" + - AliasedExpression: + - FunctionExpression.multiply: + - Field: rating + - Constant: 10 + - "ratingTimesTen" + - AliasedExpression: + - FunctionExpression.divide: + - Field: rating + - Constant: 2 + - "ratingDividedByTwo" + - AliasedExpression: + - FunctionExpression.multiply: + - Field: rating + - Constant: 20 + - "ratingTimes20" + - AliasedExpression: + - FunctionExpression.add: + - Field: rating + - Constant: 3 + - "ratingPlus3" + - AliasedExpression: + - FunctionExpression.mod: + - Field: rating + - Constant: 2 + - "ratingMod2" + assert_results: + - ratingPlusOne: 5.2 + yearsSince1900: 60 + ratingTimesTen: 42.0 + ratingDividedByTwo: 2.1 + ratingTimes20: 84 + ratingPlus3: 7.2 + ratingMod2: 0.20000000000000018 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + ratingDividedByTwo: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: divide + ratingPlusOne: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '1' + name: add + ratingTimesTen: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '10' + name: multiply + yearsSince1900: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: subtract + ratingTimes20: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '20' + name: multiply + ratingPlus3: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '3' + name: add + ratingMod2: 
+ functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: mod + name: select \ No newline at end of file diff --git a/tests/system/pipeline_e2e/string.yaml b/tests/system/pipeline_e2e/string.yaml new file mode 100644 index 000000000..20a97ba60 --- /dev/null +++ b/tests/system/pipeline_e2e/string.yaml @@ -0,0 +1,654 @@ +tests: + - description: testStringConcat + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Select: + - AliasedExpression: + - FunctionExpression.string_concat: + - Field: author + - Constant: " - " + - Field: title + - "bookInfo" + - Limit: 1 + assert_results: + - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - mapValue: + fields: + bookInfo: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: ' - ' + - fieldReferenceValue: title + name: string_concat + name: select + - args: + - integerValue: '1' + name: limit + - description: testStartsWith + pipeline: + - Collection: books + - Where: + - FunctionExpression.starts_with: + - Field: title + - Constant: The + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: The Great Gatsby + - title: The Handmaid's Tale + - title: The Hitchhiker's Guide to the Galaxy + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + name: starts_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testEndsWith + pipeline: + - Collection: books + - Where: + - FunctionExpression.ends_with: + - Field: title + - Constant: y + - Select: + - title + - Sort: + - Ordering: + - Field: title + - DESCENDING + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - title: The Great Gatsby + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: y + name: ends_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - description: testConcat + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - FunctionExpression.concat: + - Field: author + - Constant: ": " + - Field: title + - "author_title" + - AliasedExpression: + - FunctionExpression.concat: + - Field: tags + - - Constant: "new_tag" + - "concatenatedTags" + assert_results: + - author_title: "Douglas Adams: The Hitchhiker's Guide to the Galaxy" + concatenatedTags: + - comedy + - space + - adventure + - new_tag + - description: testLength + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - 
FunctionExpression.length: + - Field: title + - "titleLength" + - AliasedExpression: + - FunctionExpression.length: + - Field: tags + - "tagsLength" + - AliasedExpression: + - FunctionExpression.length: + - Field: awards + - "awardsLength" + assert_results: + - titleLength: 36 + tagsLength: 3 + awardsLength: 2 + - description: testCharLength + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - FunctionExpression.char_length: + - Field: title + - "titleLength" + - title + - Where: + - FunctionExpression.greater_than: + - Field: titleLength + - Constant: 20 + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - titleLength: 29 + title: One Hundred Years of Solitude + - titleLength: 36 + title: The Hitchhiker's Guide to the Galaxy + - titleLength: 21 + title: The Lord of the Rings + - titleLength: 21 + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + titleLength: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: titleLength + - integerValue: '20' + name: greater_than + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: CharLength + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpression: + - FunctionExpression.char_length: + - Field: title + - "title_length" + assert_results: + - title_length: 36 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + title_length: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - description: ByteLength + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: Douglas Adams + - Select: + - AliasedExpression: + - FunctionExpression.byte_length: + - FunctionExpression.string_concat: + - Field: title + - Constant: _银河系漫游指南 + - "title_byte_length" + assert_results: + - title_byte_length: 58 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + title_byte_length: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" + name: string_concat + name: byte_length + name: select + - description: testLike + pipeline: + - Collection: books + - Where: + - FunctionExpression.like: + - Field: title + - Constant: "%Guide%" + - Select: + - title + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - description: testRegexContains + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - FunctionExpression.regex_contains: + - Field: title + - Constant: "(?i)(the|of)" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - 
fieldReferenceValue: title + - stringValue: "(?i)(the|of)" + name: regex_contains + name: where + - description: testRegexMatches + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - FunctionExpression.regex_match: + - Field: title + - Constant: ".*(?i)(the|of).*" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: ".*(?i)(the|of).*" + name: regex_match + name: where + - description: testStringContains + pipeline: + - Collection: books + - Where: + - FunctionExpression.string_contains: + - Field: title + - Constant: "Hitchhiker's" + - Select: + - title + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + - description: ToLower + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpression: + - FunctionExpression.to_lower: + - Field: title + - "lower_title" + assert_results: + - lower_title: "the hitchhiker's guide to the galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + lower_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - description: ToUpper + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpression: + - FunctionExpression.to_upper: + - Field: title + - "upper_title" + assert_results: + - upper_title: "THE HITCHHIKER'S GUIDE TO THE GALAXY" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + upper_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_upper + name: select + - description: Trim + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpression: + - FunctionExpression.trim: + - FunctionExpression.string_concat: + - Constant: " " + - Field: title + - Constant: " " + - "trimmed_title" + assert_results: + - trimmed_title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + trimmed_title: + functionValue: + args: + - functionValue: + args: + - stringValue: " " + - fieldReferenceValue: title + - stringValue: " " + name: string_concat + name: trim + name: select + - description: StringReverse + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Jane Austen" + - Select: + - AliasedExpression: + - FunctionExpression.string_reverse: + - Field: title + - "reversed_title" + assert_results: + - reversed_title: "ecidujerP dna edirP" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: 
"Jane Austen" + name: equal + name: where + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: string_reverse + name: select + - description: Substring + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpression: + - FunctionExpression.substring: + - Field: title + - Constant: 4 + - Constant: 11 + - "substring_title" + assert_results: + - substring_title: "Hitchhiker'" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '4' + - integerValue: '11' + name: substring + name: select + - description: Substring without length + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Fyodor Dostoevsky" + - Select: + - AliasedExpression: + - FunctionExpression.substring: + - Field: title + - Constant: 10 + - "substring_title" + assert_results: + - substring_title: "Punishment" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Fyodor Dostoevsky" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '10' + name: substring + name: select + - description: Join + pipeline: + - Collection: books + - Where: + - FunctionExpression.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpression: + - FunctionExpression.join: + - Field: tags + - Constant: ", " + - "joined_tags" + assert_results: + - joined_tags: "comedy, space, adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + joined_tags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: ", " + name: join + name: select diff --git a/tests/system/pipeline_e2e/vector.yaml b/tests/system/pipeline_e2e/vector.yaml new file mode 100644 index 000000000..31df276b2 --- /dev/null +++ b/tests/system/pipeline_e2e/vector.yaml @@ -0,0 +1,160 @@ +tests: + - description: testVectorLength + pipeline: + - Collection: vectors + - Select: + - AliasedExpression: + - FunctionExpression.vector_length: + - Field: embedding + - "embedding_length" + - Sort: + - Ordering: + - Field: embedding_length + - ASCENDING + assert_results: + - embedding_length: 3 + - embedding_length: 3 + - embedding_length: 3 + - embedding_length: 4 + - description: testFindNearestEuclidean + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: EUCLIDEAN + options: + FindNearestOptions: + limit: 2 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 0.0 + - distance: 1.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + 
values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: euclidean + options: + limit: + integerValue: '2' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testFindNearestDotProduct + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: DOT_PRODUCT + options: + FindNearestOptions: + limit: 3 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 38.0 + - distance: 17.0 + - distance: 14.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: dot_product + options: + limit: + integerValue: '3' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testDotProductWithConstant + pipeline: + - Collection: vectors + - Where: + - FunctionExpression.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpression: + - FunctionExpression.dot_product: + - Field: embedding + - Vector: [1.0, 1.0, 1.0] + - "dot_product_result" + assert_results: + - dot_product_result: 6.0 + - description: testEuclideanDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - FunctionExpression.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpression: + - FunctionExpression.euclidean_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "euclidean_distance_result" + assert_results: + - euclidean_distance_result: 0.0 + - description: testCosineDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - FunctionExpression.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpression: + - FunctionExpression.cosine_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "cosine_distance_result" + assert_results: + - cosine_distance_result: 0.0 diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index d6ee9b944..8032ae119 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -10,7 +10,13 @@ RANDOM_ID_REGEX = re.compile("^[a-zA-Z0-9]{20}$") MISSING_DOCUMENT = "No document to update: " DOCUMENT_EXISTS = "Document already exists: " +ENTERPRISE_MODE_ERROR = "only allowed on ENTERPRISE mode" UNIQUE_RESOURCE_ID = unique_resource_id("-") EMULATOR_CREDS = EmulatorCreds() FIRESTORE_EMULATOR = os.environ.get(_FIRESTORE_EMULATOR_HOST) is not None FIRESTORE_OTHER_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db") +FIRESTORE_ENTERPRISE_DB = os.environ.get("ENTERPRISE_DATABASE", "enterprise-db-native") + +# run all tests against default database, and a named database +TEST_DATABASES = [None, FIRESTORE_OTHER_DB] +TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB] diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py new file mode 100644 index 000000000..2a83f4eaf --- /dev/null +++ b/tests/system/test_pipeline_acceptance.py @@ -0,0 +1,391 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file loads and executes yaml-encoded test cases from pipeline_e2e.yaml +""" + +from __future__ import annotations +import os +import datetime +import pytest +import yaml +import re +from typing import Any + +from google.protobuf.json_format import MessageToDict + +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_expressions +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions as expr +from google.api_core.exceptions import GoogleAPIError + +from google.cloud.firestore import Client, AsyncClient + +from test__helpers import FIRESTORE_ENTERPRISE_DB +from test__helpers import FIRESTORE_EMULATOR + +FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") + +pytestmark = pytest.mark.skipif( + condition=FIRESTORE_EMULATOR, + reason="Pipeline tests are currently not supported by emulator", +) + +test_dir_name = os.path.dirname(__file__) + +id_format = ( + lambda x: f"{x.get('file_name', '')}: {x.get('description', '')}" +) # noqa: E731 + + +def yaml_loader(field="tests", dir_name="pipeline_e2e", attach_file_name=True): + """ + Helper to load test cases or data from yaml file + """ + combined_yaml = None + for file_name in os.listdir(f"{test_dir_name}/{dir_name}"): + with open(f"{test_dir_name}/{dir_name}/{file_name}") as f: + new_yaml = yaml.safe_load(f) + assert new_yaml is not None, f"found empty yaml in {file_name}" + extracted = new_yaml.get(field, None) + # attach file_name field + if attach_file_name: + if isinstance(extracted, list): + for item in extracted: + item["file_name"] = file_name + elif isinstance(extracted, dict): + extracted["file_name"] = file_name + # aggregate files + if not combined_yaml: + combined_yaml = extracted + elif isinstance(combined_yaml, dict) and extracted: + combined_yaml.update(extracted) + elif isinstance(combined_yaml, list) and extracted: + combined_yaml.extend(extracted) + return combined_yaml + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_proto" in t], + ids=id_format, +) +def test_pipeline_parse_proto(test_dict, client): + """ + Finds assert_proto statements in yaml, and compares generated proto against expected value + """ + expected_proto = test_dict.get("assert_proto", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if proto matches as expected + if expected_proto: + got_proto = MessageToDict(pipeline._to_pb()._pb) + assert yaml.dump(expected_proto) == yaml.dump(got_proto) + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=id_format, +) +def test_pipeline_expected_errors(test_dict, client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + pipeline.execute() + found_error = str(err.value) + match = 
re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" + + +@pytest.mark.parametrize( + "test_dict", + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], + ids=id_format, +) +def test_pipeline_results(test_dict, client): + """ + Ensure pipeline returns expected results + """ + expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) + expected_count = test_dict.get("assert_count", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + got_results = [snapshot.data() for snapshot in pipeline.stream()] + if expected_results: + assert got_results == expected_results + if expected_approximate_results: + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) + if expected_count is not None: + assert len(got_results) == expected_count + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=id_format, +) +@pytest.mark.asyncio +async def test_pipeline_expected_errors_async(test_dict, async_client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + await pipeline.execute() + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" + + +@pytest.mark.parametrize( + "test_dict", + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], + ids=id_format, +) +@pytest.mark.asyncio +async def test_pipeline_results_async(test_dict, async_client): + """ + Ensure pipeline returns expected results + """ + expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) + expected_count = test_dict.get("assert_count", None) + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) + # check if server responds as expected + got_results = [snapshot.data() async for snapshot in pipeline.stream()] + if expected_results: + assert got_results == expected_results + if expected_approximate_results: + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) + if expected_count is not None: + assert len(got_results) == expected_count + + +################################################################################# +# Helpers & Fixtures +################################################################################# + + +def parse_pipeline(client, pipeline: list[dict[str, Any] | str]): + """ + parse a yaml list of pipeline stages into firestore._pipeline_stages.Stage classes + """ + result_list = [] + for stage in pipeline: + # stage will be either a map of the stage_name and its args, or just the
stage_name itself + stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0] + stage_cls: type[stages.Stage] = getattr(stages, stage_name) + # find arguments if given + if isinstance(stage, dict): + stage_yaml_args = stage[stage_name] + stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args) + else: + # yaml has no arguments + stage_obj = stage_cls() + result_list.append(stage_obj) + return client._pipeline_cls._create_with_stages(client, *result_list) + + +def _parse_expressions(client, yaml_element: Any): + """ + Turn yaml objects into pipeline expressions or native python object arguments + """ + if isinstance(yaml_element, list): + return [_parse_expressions(client, v) for v in yaml_element] + elif isinstance(yaml_element, dict): + if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))): + # build pipeline expressions if possible + cls_str = next(iter(yaml_element)) + callable_obj = None + if "." in cls_str: + cls_name, method_name = cls_str.split(".") + cls = getattr(pipeline_expressions, cls_name) + callable_obj = getattr(cls, method_name) + else: + callable_obj = getattr(pipeline_expressions, cls_str) + yaml_args = yaml_element[cls_str] + return _apply_yaml_args_to_callable(callable_obj, client, yaml_args) + elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))): + # build pipeline stage if possible (eg, for SampleOptions) + cls_str = next(iter(yaml_element)) + cls = getattr(stages, cls_str) + yaml_args = yaml_element[cls_str] + return _apply_yaml_args_to_callable(cls, client, yaml_args) + elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": + # find Pipeline objects for Union expressions + other_ppl = yaml_element["Pipeline"] + return parse_pipeline(client, other_ppl) + else: + # otherwise, return dict + return { + _parse_expressions(client, k): _parse_expressions(client, v) + for k, v in yaml_element.items() + } + elif _is_expr_string(yaml_element): + return getattr(pipeline_expressions, yaml_element)() + elif yaml_element == "NaN": + return float(yaml_element) + else: + return yaml_element + + +def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): + """ + Helper to instantiate a class with yaml arguments. The arguments will be applied + as positional or keyword arguments, based on type + """ + if isinstance(yaml_args, dict): + return callable_obj(**_parse_expressions(client, yaml_args)) + elif isinstance(yaml_args, list) and not ( + callable_obj == expr.Constant + or callable_obj == Vector + or callable_obj == expr.Array + ): + # yaml has an array of arguments. Treat as args + return callable_obj(*_parse_expressions(client, yaml_args)) + else: + # yaml has a single argument + return callable_obj(_parse_expressions(client, yaml_args)) + + +def _is_expr_string(yaml_str): + """ + Returns true if a string represents a class in pipeline_expressions + """ + if isinstance(yaml_str, str) and "." 
in yaml_str: + parts = yaml_str.split(".") + if len(parts) == 2: + cls_name, method_name = parts + if hasattr(pipeline_expressions, cls_name): + cls = getattr(pipeline_expressions, cls_name) + if hasattr(cls, method_name): + return True + return ( + isinstance(yaml_str, str) + and yaml_str[0].isupper() + and hasattr(pipeline_expressions, yaml_str) + ) + + +def _is_stage_string(yaml_str): + """ + Returns true if a string represents a class in pipeline_stages + """ + return ( + isinstance(yaml_str, str) + and yaml_str[0].isupper() + and hasattr(stages, yaml_str) + ) + + +@pytest.fixture(scope="module") +def event_loop(): + """Change event_loop fixture to module level.""" + import asyncio + + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() + + +def _parse_yaml_types(data): + """helper to convert yaml data to firestore objects when needed""" + if isinstance(data, dict): + return {key: _parse_yaml_types(value) for key, value in data.items()} + if isinstance(data, list): + # detect vectors + if all([isinstance(d, float) for d in data]): + return Vector(data) + else: + return [_parse_yaml_types(value) for value in data] + # detect timestamps + if isinstance(data, str) and ":" in data: + try: + parsed_datetime = datetime.datetime.fromisoformat(data) + return parsed_datetime + except ValueError: + pass + if data == "NaN": + return float("NaN") + return data + + +@pytest.fixture(scope="module") +def client(): + """ + Build a client to use for requests + """ + client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) + data = yaml_loader("data", attach_file_name=False) + to_delete = [] + try: + # setup data + batch = client.batch() + for collection_name, documents in data.items(): + collection_ref = client.collection(collection_name) + for document_id, document_data in documents.items(): + document_ref = collection_ref.document(document_id) + to_delete.append(document_ref) + batch.set(document_ref, _parse_yaml_types(document_data)) + batch.commit() + yield client + finally: + # clear data + for document_ref in to_delete: + document_ref.delete() + + +@pytest.fixture(scope="module") +def async_client(client): + """ + Build an async client to use for AsyncPipeline requests + """ + yield AsyncClient(project=client.project, database=client._database) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index c66340de1..0c86c69a3 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -21,6 +21,7 @@ import google.auth import pytest +import mock from google.api_core.exceptions import ( AlreadyExists, FailedPrecondition, @@ -38,11 +39,14 @@ EMULATOR_CREDS, FIRESTORE_CREDS, FIRESTORE_EMULATOR, - FIRESTORE_OTHER_DB, FIRESTORE_PROJECT, MISSING_DOCUMENT, RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, + ENTERPRISE_MODE_ERROR, + TEST_DATABASES, + TEST_DATABASES_W_ENTERPRISE, + FIRESTORE_ENTERPRISE_DB, ) @@ -80,13 +84,71 @@ def cleanup(): operation() -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def verify_pipeline(query): + """ + This function ensures a pipeline produces the same + results as the query it is derived from + + It can be attached to existing query tests to check both + modalities at the same time + """ + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + if FIRESTORE_EMULATOR: + pytest.skip("skip pipeline verification on emulator") + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v 
in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + ) + ) + ) + else: + # other queries return a simple list of results + query_results = _clean_results([s.to_dict() for s in query.get()]) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + client = query._client + pipeline = client.pipeline().create_from(query) + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results([s.data() for s in pipeline.execute()]) + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collections(client, database): collections = list(client.collections()) assert isinstance(collections, list) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB]) +@pytest.mark.parametrize("database", TEST_DATABASES) def test_collections_w_import(database): from google.cloud import firestore @@ -103,7 +165,7 @@ def test_collections_w_import(database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -125,7 +187,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["get", "stream"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_false( database, method, query_docs ): @@ -163,7 +225,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
) @pytest.mark.parametrize("method", ["get", "stream"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_true( database, method, query_docs ): @@ -217,7 +279,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_true( assert len(execution_stats.debug_stats) > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collections_w_read_time(client, cleanup, database): first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID first_document_id = "doc" + UNIQUE_RESOURCE_ID @@ -228,7 +290,6 @@ def test_collections_w_read_time(client, cleanup, database): data = {"status": "new"} write_result = first_document.create(data) read_time = write_result.update_time - num_collections = len(list(client.collections())) second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" @@ -238,7 +299,6 @@ def test_collections_w_read_time(client, cleanup, database): # Test that listing current collections does have the second id. curr_collections = list(client.collections()) - assert len(curr_collections) > num_collections ids = [collection.id for collection in curr_collections] assert second_collection_id in ids assert first_collection_id in ids @@ -250,7 +310,7 @@ def test_collections_w_read_time(client, cleanup, database): assert first_collection_id in ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) collection_id = "doc-create" + UNIQUE_RESOURCE_ID @@ -295,7 +355,7 @@ def test_create_document(client, cleanup, database): assert stored_data == expected_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document_w_vector(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document1 = client.document(collection_id, "doc1") @@ -326,7 +386,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -355,7 +415,7 @@ def test_vector_search_collection(client, database, distance_measure): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -384,7 +444,7 @@ def test_vector_search_collection_with_filter(client, database, distance_measure @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_with_distance_parameters_euclid(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ 
-414,7 +474,7 @@ def test_vector_search_collection_with_distance_parameters_euclid(client, databa @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_with_distance_parameters_cosine(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -444,7 +504,7 @@ def test_vector_search_collection_with_distance_parameters_cosine(client, databa @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -480,7 +540,7 @@ def test_vector_search_collection_group(client, database, distance_measure): DistanceMeasure.COSINE, ], ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_filter(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -502,7 +562,7 @@ def test_vector_search_collection_group_with_filter(client, database, distance_m @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_distance_parameters_euclid( client, database ): @@ -534,7 +594,7 @@ def test_vector_search_collection_group_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_distance_parameters_cosine( client, database ): @@ -569,7 +629,7 @@ def test_vector_search_collection_group_with_distance_parameters_cosine( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_no_explain_options(client, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -599,7 +659,7 @@ def test_vector_query_stream_or_get_w_no_explain_options(client, database, metho FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_explain_options_analyze_true( client, database, method ): @@ -668,7 +728,7 @@ def test_vector_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_explain_options_analyze_false( client, database, method ): @@ -715,7 +775,7 @@ def test_vector_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -741,7 +801,7 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_collections_w_read_time(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -777,7 +837,7 @@ def test_document_collections_w_read_time(client, cleanup, database): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID document = client.document("abcde", document_id) @@ -785,7 +845,7 @@ def test_no_document(client, database): assert snapshot.to_dict() is None -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -815,7 +875,7 @@ def test_document_set(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_integer_field(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -832,7 +892,7 @@ def test_document_integer_field(client, cleanup, database): assert snapshot.to_dict() == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set_merge(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -865,7 +925,7 @@ def test_document_set_merge(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set_w_int_field(client, cleanup, database): document_id = "set-int-key" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -889,7 +949,7 @@ def test_document_set_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_update_w_int_field(client, 
cleanup, database): # Attempt to reproduce #5489. document_id = "update-int-key" + UNIQUE_RESOURCE_ID @@ -917,7 +977,7 @@ def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_update_document(client, cleanup, database): document_id = "for-update" + UNIQUE_RESOURCE_ID document = client.document("made", document_id) @@ -989,7 +1049,7 @@ def check_snapshot(snapshot, document, data, write_result): assert snapshot.update_time == write_result.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_get(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) document_id = "for-get" + UNIQUE_RESOURCE_ID @@ -1015,7 +1075,7 @@ def test_document_get(client, cleanup, database): check_snapshot(snapshot, document, data, write_result) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_delete(client, cleanup, database): document_id = "deleted" + UNIQUE_RESOURCE_ID document = client.document("here-to-be", document_id) @@ -1052,7 +1112,7 @@ def test_document_delete(client, cleanup, database): assert_timestamp_less(delete_time3, delete_time4) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_add(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. @@ -1141,7 +1201,7 @@ def test_collection_add(client, cleanup, database): assert set(collection3.list_documents()) == {document_ref5} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_list_collections_with_read_time(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. 
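# Illustrative sketch (not part of the diff): one plausible shape for the new
# constants imported from the shared system-test helpers above. Only the names
# TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, FIRESTORE_ENTERPRISE_DB and
# ENTERPRISE_MODE_ERROR are established by this change; the values shown here
# are assumptions for illustration only.
FIRESTORE_OTHER_DB = "system-tests-named-db"            # pre-existing secondary database (assumed id)
FIRESTORE_ENTERPRISE_DB = "system-tests-enterprise-db"  # enterprise-mode database (assumed id)
ENTERPRISE_MODE_ERROR = "<substring of the FailedPrecondition raised by non-enterprise databases>"  # placeholder

# None selects the default database; enterprise-aware tests additionally cover the
# enterprise database so verify_pipeline() can exercise the pipeline code path.
TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB]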
@@ -1166,7 +1226,7 @@ def test_list_collections_with_read_time(client, cleanup, database): } -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_unicode_doc(client, cleanup, database): collection_id = "coll-unicode" + UNIQUE_RESOURCE_ID collection = client.collection(collection_id) @@ -1233,7 +1293,7 @@ def query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs @@ -1247,9 +1307,10 @@ def test_query_stream_legacy_where(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1258,9 +1319,10 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1269,9 +1331,10 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1281,9 +1344,10 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_not_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) @@ -1303,9 +1367,10 @@ def test_query_stream_w_not_eq_op(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_not_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1315,9 +1380,10 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): values = {snapshot.id: snapshot.to_dict() for snapshot in 
query.stream()} assert len(values) == 22 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1329,9 +1395,10 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1343,9 +1410,10 @@ def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. assert sorted(b_vals, reverse=True) == b_vals + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1365,9 +1433,10 @@ def test_query_stream_w_field_path(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1383,16 +1452,17 @@ def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = list(query.stream()) assert len(values) == 0 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_projection(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1409,7 +1479,7 @@ def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1427,9 +1497,10 @@ def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1443,13 +1514,14 @@ def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1465,13 +1537,14 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): # is called with pytest.raises(QueryExplainError, match="explain_options not set on query"): results.get_explain_metrics() + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1531,7 +1604,7 @@ def test_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1571,7 +1644,125 @@ def test_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_explain_options_explain_mode(database, method, query_docs): + """Explain currently not supported by backend. Expect error""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ) + + collection, _, _ = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # Tests either `execute()` or `stream()`. + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions(mode="explain") + + # for now, expect error on explain mode + with pytest.raises(InvalidArgument) as e: + results = method_under_test(explain_options=explain_options) + list(results) + assert "Explain execution mode is not supported" in str(e) + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
+) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_explain_options_analyze_mode(database, method, query_docs): + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # Tests either `execute()` or `stream()`. + method_under_test = getattr(pipeline, method) + results = method_under_test(explain_options=PipelineExplainOptions()) + + if method == "stream": + # check for error accessing explain stats before iterating + with pytest.raises( + QueryExplainError, + match="explain_stats not available until query is complete", + ): + results.explain_stats + + # Finish iterating results, and explain_stats should be available. + results_list = list(results) + num_results = len(results_list) + assert num_results == len(allowed_vals) + + # Verify explain_stats. + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_explain_options_using_additional_options( + database, method, query_docs +): + """additional_options field allows passing in arbitrary options. Test with explain_options""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # Tests either `execute()` or `stream()`. + method_under_test = getattr(pipeline, method) + + encoded_options = {"explain_options": PipelineExplainOptions()._to_value()} + + results = method_under_test( + explain_options=mock.Mock(), additional_options=encoded_options + ) + + # Finish iterating results, and explain_stats should be available./w_read + results_list = list(results) + num_results = len(results_list) + assert num_results == len(allowed_vals) + + # Verify explain_stats. + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1609,7 +1800,53 @@ def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Pipeline query not supported in emulator." 
+) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find a read_time before adding the new document. + snapshots = collection.get() + read_time = snapshots[0].read_time + + new_data = { + "a": 9000, + "b": 1, + } + _, new_ref = collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + client = collection._client + query = collection.where(filter=FieldFilter("b", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # new query should have new_data + new_results = list(pipeline.stream()) + new_values = {result.ref.id: result.data() for result in new_results} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + # query with read_time should not have new)data + results = list(pipeline.stream(read_time=read_time)) + + values = {result.ref.id: result.data() for result in results} + + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref.id + + +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_order_dot_key(client, cleanup, database): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID @@ -1622,15 +1859,16 @@ def test_query_with_order_dot_key(client, cleanup, database): query = collection.order_by("wordcount.page1").limit(3) data = [doc.to_dict()["wordcount"]["page1"] for doc in query.stream()] assert [100, 110, 120] == data - for snapshot in collection.order_by("wordcount.page1").limit(3).stream(): + query2 = collection.order_by("wordcount.page1").limit(3) + for snapshot in query2.stream(): last_value = snapshot.get("wordcount.page1") cursor_with_nested_keys = {"wordcount": {"page1": last_value}} - found = list( + query3 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_nested_keys) .limit(3) - .stream() ) + found = list(query3.stream()) found_data = [ {"count": 30, "wordcount": {"page1": 130}}, {"count": 40, "wordcount": {"page1": 140}}, @@ -1638,16 +1876,16 @@ def test_query_with_order_dot_key(client, cleanup, database): ] assert found_data == [snap.to_dict() for snap in found] cursor_with_dotted_paths = {"wordcount.page1": last_value} - cursor_with_key_data = list( + query4 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_dotted_paths) .limit(3) - .stream() ) + cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1673,6 +1911,7 @@ def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + verify_pipeline(query0) # 1. Query for a NAN. 
query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1683,6 +1922,7 @@ def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + verify_pipeline(query1) # 2. Query for not null query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) @@ -1702,7 +1942,7 @@ def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1732,10 +1972,11 @@ def test_collection_group_queries(client, cleanup, database): snapshots = list(query.stream()) found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] - assert found == expected + assert set(found) == set(expected) + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_group_queries_startat_endat(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1778,7 +2019,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1821,6 +2062,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1842,12 +2084,13 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_partition_query_no_partitions(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1882,7 +2125,7 @@ def test_partition_query_no_partitions(client, cleanup, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_partition_query(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID n_docs = 128 * 2 + 127 # Minimum partition size is 128 @@ -1910,7 +2153,7 @@ def test_partition_query(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_get_all(client, cleanup, database): collection_name = "get-all" + UNIQUE_RESOURCE_ID @@ -1986,7 +2229,7 @@ def 
test_get_all(client, cleanup, database): check_snapshot(snapshot3, document3, data3, write_result3) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_batch(client, cleanup, database): collection_name = "batch" + UNIQUE_RESOURCE_ID @@ -2032,7 +2275,7 @@ def test_batch(client, cleanup, database): assert not document3.get().exists -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_live_bulk_writer(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriter from google.cloud.firestore_v1.client import Client @@ -2056,7 +2299,7 @@ def test_live_bulk_writer(client, cleanup, database): assert len(col.get()) == 50 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_document(client, cleanup, database): db = client collection_ref = db.collection("wd-users" + UNIQUE_RESOURCE_ID) @@ -2093,7 +2336,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_collection(client, cleanup, database): db = client collection_ref = db.collection("wc-users" + UNIQUE_RESOURCE_ID) @@ -2130,7 +2373,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_query(client, cleanup, database): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) @@ -2148,9 +2391,8 @@ def on_snapshot(docs, changes, read_time): on_snapshot.called_count += 1 # A snapshot should return the same thing as if a query ran now. 
- query_ran = collection_ref.where( - filter=FieldFilter("first", "==", "Ada") - ).stream() + query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) + query_ran = query_ran_query.stream() assert len(docs) == len([i for i in query_ran]) on_snapshot.called_count = 0 @@ -2172,7 +2414,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_array_union(client, cleanup, database): doc_ref = client.document("gcp-7523", "test-document") cleanup(doc_ref.delete) @@ -2319,7 +2561,7 @@ def _do_recursive_delete(client, bulk_writer, empty_philosophers=False): ), f"Snapshot at Socrates{path} should have been deleted" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_parallelized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2327,7 +2569,7 @@ def test_recursive_delete_parallelized(client, cleanup, database): _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_serialized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2335,7 +2577,7 @@ def test_recursive_delete_serialized(client, cleanup, database): _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_parallelized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2343,7 +2585,7 @@ def test_recursive_delete_parallelized_empty(client, cleanup, database): _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_serialized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2351,12 +2593,13 @@ def test_recursive_delete_serialized_empty(client, cleanup, database): _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_query(client, cleanup, database): col_id: str = f"philosophers-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) - ids = [doc.id for doc in client.collection_group(col_id).recursive().get()] + query = client.collection_group(col_id).recursive() + ids = [doc.id for doc in query.get()] expected_ids = [ # Aristotle doc and subdocs @@ -2390,14 +2633,15 @@ def test_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_nested_recursive_query(client, cleanup, database): col_id: str = f"philosophers-nested-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, 
cleanup) collection_ref = client.collection(col_id) aristotle = collection_ref.document("Aristotle") - ids = [doc.id for doc in aristotle.collection("pets").recursive().get()] + query = aristotle.collection("pets").recursive() + ids = [doc.id for doc in query.get()] expected_ids = [ # Aristotle pets @@ -2414,7 +2658,7 @@ def test_nested_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_chunked_query(client, cleanup, database): col = client.collection(f"chunked-test{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2429,7 +2673,7 @@ def test_chunked_query(client, cleanup, database): assert len(next(iter)) == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_chunked_query_smaller_limit(client, cleanup, database): col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2441,7 +2685,7 @@ def test_chunked_query_smaller_limit(client, cleanup, database): assert len(next(iter)) == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_chunked_and_recursive(client, cleanup, database): col_id = f"chunked-recursive-test{UNIQUE_RESOURCE_ID}" documents = [ @@ -2490,7 +2734,7 @@ def test_chunked_and_recursive(client, cleanup, database): assert [doc.id for doc in next(iter)] == page_3_ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_query_order(client, cleanup, database): db = client collection_ref = db.collection("users") @@ -2566,7 +2810,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_repro_429(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -2592,9 +2836,10 @@ def test_repro_429(client, cleanup, database): for snapshot in query2.stream(): print(f"id: {snapshot.id}") + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_repro_391(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/391 now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -2609,7 +2854,7 @@ def test_repro_391(client, cleanup, database): assert len(set(collection.stream())) == len(document_ids) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_default_alias(query, database): count_query = query.count() result = count_query.get() @@ -2618,7 +2863,7 @@ def test_count_query_get_default_alias(query, database): assert r.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_with_alias(query, database): count_query = query.count(alias="total") result = count_query.get() 
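# Illustrative sketch (not part of the diff): how verify_pipeline(), defined earlier
# in this change, compares an aggregation query against the pipeline derived from it
# via create_from(). Assumes an enterprise-mode database (e.g. FIRESTORE_ENTERPRISE_DB)
# and a populated collection; the collection id below is hypothetical.
import itertools

from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

client = firestore.Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB)
collection = client.collection("query-docs-example")  # hypothetical collection id

count_query = collection.where(filter=FieldFilter("a", "==", 1)).count(alias="total")

# Aggregation queries return a list of result batches; flatten each batch into dicts.
query_rows = list(
    itertools.chain.from_iterable(
        [[agg._to_dict() for agg in batch] for batch in count_query.get()]
    )
)

# The same aggregation expressed as a pipeline should yield matching rows.
pipeline = client.pipeline().create_from(count_query)
pipeline_rows = [result.data() for result in pipeline.execute()]
assert query_rows == pipeline_rows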
@@ -2627,7 +2872,7 @@ def test_count_query_get_with_alias(query, database): assert r.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_with_limit(query, database): # count without limit count_query = query.count(alias="total") @@ -2647,7 +2892,7 @@ def test_count_query_get_with_limit(query, database): assert r.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_multiple_aggregations(query, database): count_query = query.count(alias="total").count(alias="all") @@ -2662,7 +2907,7 @@ def test_count_query_get_multiple_aggregations(query, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_multiple_aggregations_duplicated_alias(query, database): count_query = query.count(alias="total").count(alias="total") @@ -2672,7 +2917,7 @@ def test_count_query_get_multiple_aggregations_duplicated_alias(query, database) assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_empty_aggregation(query, database): from google.cloud.firestore_v1.aggregation import AggregationQuery @@ -2684,7 +2929,7 @@ def test_count_query_get_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_default_alias(query, database): count_query = query.count() for result in count_query.stream(): @@ -2692,7 +2937,7 @@ def test_count_query_stream_default_alias(query, database): assert aggregation_result.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_with_alias(query, database): count_query = query.count(alias="total") for result in count_query.stream(): @@ -2700,7 +2945,7 @@ def test_count_query_stream_with_alias(query, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_with_limit(query, database): # count without limit count_query = query.count(alias="total") @@ -2718,7 +2963,7 @@ def test_count_query_stream_with_limit(query, database): assert aggregation_result.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_multiple_aggregations(query, database): count_query = query.count(alias="total").count(alias="all") @@ -2727,7 +2972,7 @@ def test_count_query_stream_multiple_aggregations(query, database): assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def 
test_count_query_stream_multiple_aggregations_duplicated_alias(query, database): count_query = query.count(alias="total").count(alias="total") @@ -2738,7 +2983,7 @@ def test_count_query_stream_multiple_aggregations_duplicated_alias(query, databa assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_empty_aggregation(query, database): from google.cloud.firestore_v1.aggregation import AggregationQuery @@ -2751,7 +2996,7 @@ def test_count_query_stream_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_with_start_at(query, database): """ Ensure that count aggregation queries work when chained with a start_at @@ -2770,7 +3015,7 @@ def test_count_query_with_start_at(query, database): assert aggregation_result.value == expected_count -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_default_alias(collection, database): sum_query = collection.sum("stats.product") result = sum_query.get() @@ -2780,7 +3025,7 @@ def test_sum_query_get_default_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") result = sum_query.get() @@ -2790,7 +3035,7 @@ def test_sum_query_get_with_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2811,7 +3056,7 @@ def test_sum_query_get_with_limit(collection, database): assert r.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2828,7 +3073,7 @@ def test_sum_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_default_alias(collection, database): sum_query = collection.sum("stats.product") for result in sum_query.stream(): @@ -2837,7 +3082,7 @@ def test_sum_query_stream_default_alias(collection, database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") for result in sum_query.stream(): @@ -2846,7 +3091,7 @@ def test_sum_query_stream_with_alias(collection, 
database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2864,7 +3109,7 @@ def test_sum_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2878,7 +3123,7 @@ def test_sum_query_stream_multiple_aggregations(collection, database): # tests for issue reported in b/306241058 # we will skip test in client for now, until backend fix is implemented @pytest.mark.skip(reason="backend fix required") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_with_start_at(query, database): """ Ensure that sum aggregation queries work when chained with a start_at @@ -2896,7 +3141,7 @@ def test_sum_query_with_start_at(query, database): assert sum_result[0].value == expected_sum -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_default_alias(collection, database): avg_query = collection.avg("stats.product") result = avg_query.get() @@ -2907,7 +3152,7 @@ def test_avg_query_get_default_alias(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") result = avg_query.get() @@ -2917,7 +3162,7 @@ def test_avg_query_get_with_alias(collection, database): assert r.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2939,7 +3184,7 @@ def test_avg_query_get_with_limit(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2956,7 +3201,7 @@ def test_avg_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_default_alias(collection, database): avg_query = collection.avg("stats.product") for result in avg_query.stream(): @@ -2965,7 +3210,7 @@ def test_avg_query_stream_default_alias(collection, database): assert aggregation_result.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") for result in avg_query.stream(): @@ -2974,7 +3219,7 @@ def test_avg_query_stream_with_alias(collection, database): assert aggregation_result.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2992,7 +3237,7 @@ def test_avg_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 / 12 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -3006,7 +3251,7 @@ def test_avg_query_stream_multiple_aggregations(collection, database): # tests for issue reported in b/306241058 # we will skip test in client for now, until backend fix is implemented @pytest.mark.skip(reason="backend fix required") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_with_start_at(query, database): """ Ensure that avg aggregation queries work when chained with a start_at @@ -3030,7 +3275,7 @@ def test_avg_query_with_start_at(query, database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_stream_or_get_w_no_explain_options(query, database, method): # Because all aggregation methods end up calling AggregationQuery.get() or # AggregationQuery.stream(), only use count() for testing here. @@ -3056,7 +3301,7 @@ def test_aggregation_query_stream_or_get_w_no_explain_options(query, database, m FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_stream_or_get_w_explain_options_analyze_true( query, database, method ): @@ -3120,7 +3365,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( query, database, method ): @@ -3160,7 +3405,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ @@ -3173,9 +3418,10 @@ def test_query_with_and_composite_filter(collection, database): for result in query.stream(): assert result.get("stats.product") > 5 assert result.get("stats.product") < 10 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_or_composite_filter(collection, database): or_filter = Or( filters=[ @@ -3196,14 +3442,20 @@ def test_query_with_or_composite_filter(collection, database): assert gt_5 > 0 assert lt_10 > 0 + verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) def test_aggregation_queries_with_read_time( - collection, query, cleanup, database, aggregation_type, expected_value + collection, + query, + cleanup, + database, + aggregation_type, + expected_value, ): """ Ensure that all aggregation queries work when read_time is passed into @@ -3238,9 +3490,10 @@ def test_aggregation_queries_with_read_time( assert len(old_result) == 1 for r in old_result[0]: assert r.value == expected_value + verify_pipeline(aggregation_query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( @@ -3261,6 +3514,7 @@ def test_query_with_complex_composite_filter(collection, database): assert sum_0 > 0 assert sum_4 > 0 + verify_pipeline(query) # b == 3 || (stats.sum == 4 && a == 4) comp_filter = Or( @@ -3283,15 +3537,21 @@ def test_query_with_complex_composite_filter(collection, database): assert b_3 is True assert b_not_3 is True + verify_pipeline(query) @pytest.mark.parametrize( "aggregation_type,aggregation_args,expected", [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)], ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_aggregation_query_in_transaction( - client, cleanup, database, aggregation_type, aggregation_args, expected + client, + cleanup, + database, + aggregation_type, + aggregation_args, + expected, ): """ Test creating an aggregation query inside a transaction @@ -3325,13 +3585,14 @@ def in_transaction(transaction): assert len(result[0]) == 1 assert result[0][0].value == expected inner_fn_ran = True + verify_pipeline(aggregation_query) in_transaction(transaction) # make sure we didn't skip assertions in inner function assert inner_fn_ran is True 
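# Illustrative sketch (an assumption, not asserted by this diff): the synchronous
# Pipeline is expected to mirror AsyncPipeline.execute()/stream() and accept an
# existing transaction, so a pipeline derived from a query could run inside
# client.transaction() much like the transactional query tests here. The
# transaction= keyword and the collection id below are assumptions.
from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

client = firestore.Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB)
collection = client.collection("transaction-pipeline-example")  # hypothetical collection id
transaction = client.transaction()

@firestore.transactional
def read_in_transaction(transaction):
    # Derive a pipeline from a regular query and read it under the transaction.
    query = collection.where(filter=FieldFilter("a", "==", 1))
    pipeline = client.pipeline().create_from(query)
    return [result.data() for result in pipeline.stream(transaction=transaction)]

rows = read_in_transaction(transaction)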
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_or_query_in_transaction(client, cleanup, database): """ Test running or query inside a transaction. Should pass transaction id along with request @@ -3370,13 +3631,14 @@ def in_transaction(transaction): result[0].get("b") == 2 and result[1].get("b") == 1 ) inner_fn_ran = True + verify_pipeline(query) in_transaction(transaction) # make sure we didn't skip assertions in inner function assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_transaction_w_uuid(client, cleanup, database): """ https://github.com/googleapis/python-firestore/issues/1012 @@ -3401,7 +3663,7 @@ def update_doc(tx, doc_ref, key, value): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_in_transaction_with_explain_options(client, cleanup, database): """ Test query profiling in transactions. @@ -3453,7 +3715,7 @@ def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_in_transaction_with_read_time(client, cleanup, database): """ Test query profiling in transactions. @@ -3499,7 +3761,7 @@ def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_update_w_uuid(client, cleanup, database): """ https://github.com/googleapis/python-firestore/issues/1012 @@ -3518,7 +3780,7 @@ def test_update_w_uuid(client, cleanup, database): @pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_transaction_rollback(client, cleanup, database, with_rollback, expected): """ Create a document in a transaction that is rolled back diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 945e7cb12..1442e7932 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -22,6 +22,7 @@ import google.auth import pytest import pytest_asyncio +import mock from google.api_core import exceptions as core_exceptions from google.api_core import retry_async as retries from google.api_core.exceptions import ( @@ -49,11 +50,14 @@ EMULATOR_CREDS, FIRESTORE_CREDS, FIRESTORE_EMULATOR, - FIRESTORE_OTHER_DB, FIRESTORE_PROJECT, MISSING_DOCUMENT, RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, + ENTERPRISE_MODE_ERROR, + TEST_DATABASES, + TEST_DATABASES_W_ENTERPRISE, + FIRESTORE_ENTERPRISE_DB, ) RETRIES = retries.AsyncRetry( @@ -160,6 +164,66 @@ async def cleanup(): await operation() +async def verify_pipeline(query): + """ + This function ensures a pipeline produces the same + results as the query it is derived from + + It can be attached to existing query tests to check both + modalities at the same time + """ + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + if FIRESTORE_EMULATOR: + 
pytest.skip("skip pipeline verification on emulator") + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in await query.get()] + ) + ) + ) + else: + # other queries return a simple list of results + query_results = _clean_results([s.to_dict() for s in await query.get()]) + except Exception as e: + # if the query fails, capture the exception so the pipeline can be checked for the same error + query_exception = e + client = query._client + pipeline = client.pipeline().create_from(query) + if query_exception: + # ensure that the pipeline raises the same error as the query + with pytest.raises(query_exception.__class__): + await pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results( + [s.data() async for s in pipeline.stream()] + ) + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + @pytest.fixture(scope="module") def event_loop(): """Change event_loop fixture to module level.""" @@ -169,13 +233,13 @@ def event_loop(): loop.close() -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collections(client, database): collections = [x async for x in client.collections(retry=RETRIES)] assert isinstance(collections, list) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB]) +@pytest.mark.parametrize("database", TEST_DATABASES) async def test_collections_w_import(database): from google.cloud import firestore @@ -188,7 +252,7 @@ async def test_collections_w_import(database): assert isinstance(collections, list) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_create_document(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) collection_id = "doc-create" + UNIQUE_RESOURCE_ID @@ -234,7 +298,7 @@ async def test_create_document(client, cleanup, database): assert stored_data == expected_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collections_w_read_time(client, cleanup, database): first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID first_document_id = "doc" + UNIQUE_RESOURCE_ID @@ -245,7 +309,6 @@ async def test_collections_w_read_time(client, cleanup, database): data = {"status": "new"} write_result = await first_document.create(data) read_time = write_result.update_time - num_collections = len([x async for x in client.collections(retry=RETRIES)]) second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" @@ -255,7 +318,6 @@ async def test_collections_w_read_time(client, cleanup, database): # Test that listing current collections does have the second id.
curr_collections = [x async for x in client.collections(retry=RETRIES)] - assert len(curr_collections) > num_collections ids = [collection.id for collection in curr_collections] assert second_collection_id in ids assert first_collection_id in ids @@ -269,7 +331,7 @@ async def test_collections_w_read_time(client, cleanup, database): assert first_collection_id in ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -295,7 +357,7 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_collections_w_read_time(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -331,7 +393,7 @@ async def test_document_collections_w_read_time(client, cleanup, database): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID document = client.document("abcde", document_id) @@ -339,7 +401,7 @@ async def test_no_document(client, database): assert snapshot.to_dict() is None -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_set(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -369,7 +431,7 @@ async def test_document_set(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_integer_field(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -386,7 +448,7 @@ async def test_document_integer_field(client, cleanup, database): assert snapshot.to_dict() == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_set_merge(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -419,7 +481,7 @@ async def test_document_set_merge(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_set_w_int_field(client, cleanup, database): document_id = "set-int-key" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -443,7 +505,7 @@ async def test_document_set_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def 
test_document_update_w_int_field(client, cleanup, database): # Attempt to reproduce #5489. document_id = "update-int-key" + UNIQUE_RESOURCE_ID @@ -471,7 +533,7 @@ async def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -499,7 +561,7 @@ async def test_vector_search_collection(client, database, distance_measure): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -527,7 +589,7 @@ async def test_vector_search_collection_with_filter(client, database, distance_m @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_with_distance_parameters_euclid( client, database ): @@ -559,7 +621,7 @@ async def test_vector_search_collection_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_with_distance_parameters_cosine( client, database ): @@ -591,7 +653,7 @@ async def test_vector_search_collection_with_distance_parameters_cosine( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -620,7 +682,7 @@ async def test_vector_search_collection_group(client, database, distance_measure @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -651,7 +713,7 @@ async def test_vector_search_collection_group_with_filter( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_group_with_distance_parameters_euclid( client, database ): @@ -683,7 +745,7 @@ async def test_vector_search_collection_group_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_group_with_distance_parameters_cosine( client, database ): @@ -718,7 +780,7 @@ async def test_vector_search_collection_group_with_distance_parameters_cosine( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_query_stream_or_get_w_no_explain_options( client, database, method ): @@ -753,7 +815,7 @@ async def test_vector_query_stream_or_get_w_no_explain_options( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_query_stream_or_get_w_explain_options_analyze_true( client, query_docs, database, method ): @@ -832,7 +894,7 @@ async def test_vector_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_query_stream_or_get_w_explain_options_analyze_false( client, query_docs, database, method ): @@ -895,7 +957,7 @@ async def test_vector_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_update_document(client, cleanup, database): document_id = "for-update" + UNIQUE_RESOURCE_ID document = client.document("made", document_id) @@ -968,7 +1030,7 @@ def check_snapshot(snapshot, document, data, write_result): assert snapshot.update_time == write_result.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_get(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) document_id = "for-get" + UNIQUE_RESOURCE_ID @@ -994,7 +1056,7 @@ async def test_document_get(client, cleanup, database): check_snapshot(snapshot, document, data, write_result) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_delete(client, cleanup, database): document_id = "deleted" + UNIQUE_RESOURCE_ID document = client.document("here-to-be", document_id) @@ -1031,7 +1093,7 @@ async def test_document_delete(client, cleanup, database): assert_timestamp_less(delete_time3, delete_time4) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_add(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. @@ -1133,7 +1195,7 @@ async def test_collection_add(client, cleanup, database): assert set([i async for i in collection3.list_documents()]) == {document_ref5} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_list_collections_with_read_time(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. 
# Consider if this is desired. Also, Document isn't hashable. @@ -1205,7 +1267,7 @@ async def async_query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value, and shows UserWarning""" collection, stored, allowed_vals = query_docs @@ -1219,9 +1281,10 @@ async def test_query_stream_legacy_where(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1230,9 +1293,10 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1241,9 +1305,10 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1253,9 +1318,10 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1267,9 +1333,10 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1281,9 +1348,10 @@ async def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. 
assert sorted(b_vals, reverse=True) == b_vals + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1303,9 +1371,10 @@ async def test_query_stream_w_field_path(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1321,16 +1390,17 @@ async def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = [i async for i in query.stream()] assert len(values) == 0 + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_projection(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1347,7 +1417,7 @@ async def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1365,9 +1435,10 @@ async def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1381,13 +1452,14 @@ async def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + await verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1406,13 +1478,14 @@ async def test_query_stream_or_get_w_no_explain_options(query_docs, database, me # is called with pytest.raises(QueryExplainError, match="explain_options not set on query"): await results.get_explain_metrics() + await verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1457,7 +1530,7 @@ async def test_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1492,7 +1565,172 @@ async def test_query_stream_or_get_w_explain_options_analyze_false( _verify_explain_metrics_analyze_false(explain_metrics) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_explain_mode(database, method, query_docs): + """Explain currently not supported by backend. Expect error""" + from google.api_core.exceptions import InvalidArgument + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ) + + collection, _, _ = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions(mode="explain") + + with pytest.raises(InvalidArgument) as e: + if method == "stream": + results = method_under_test(explain_options=explain_options) + _ = [i async for i in results] + else: + await method_under_test(explain_options=explain_options) + + assert "Explain execution mode is not supported" in str(e.value) + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
+) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_analyze_mode(database, method, query_docs): + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions() + + if method == "execute": + results = await method_under_test(explain_options=explain_options) + num_results = len(results) + else: + results = method_under_test(explain_options=explain_options) + with pytest.raises( + QueryExplainError, + match="explain_stats not available until query is complete", + ): + results.explain_stats + + num_results = len([item async for item in results]) + + explain_stats = results.explain_stats + + assert num_results == len(allowed_vals) + + assert isinstance(explain_stats, ExplainStats) + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_using_additional_options( + database, method, query_docs +): + """additional_options field allows passing in arbitrary options. Test with explain_options""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + encoded_options = {"explain_options": PipelineExplainOptions()._to_value()} + + stub = method_under_test( + explain_options=mock.Mock(), additional_options=encoded_options + ) + if method == "execute": + results = await stub + num_results = len(results) + else: + results = stub + num_results = len([item async for item in results]) + + assert num_results == len(allowed_vals) + + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Pipeline query not supported in emulator." +) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find a read_time before adding the new document. + snapshots = await collection.get() + read_time = snapshots[0].read_time + + new_data = { + "a": 9000, + "b": 1, + } + _, new_ref = await collection.add(new_data) + # Add to clean-up. 
+ cleanup(new_ref.delete) + stored[new_ref.id] = new_data + client = collection._client + query = collection.where(filter=FieldFilter("b", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # new query should have new_data + new_results = [result async for result in pipeline.stream()] + new_values = {result.ref.id: result.data() for result in new_results} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + # pipeline with read_time should not have new_data + results = [result async for result in pipeline.stream(read_time=read_time)] + + values = {result.ref.id: result.data() for result in results} + + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref.id + + +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1532,7 +1770,7 @@ async def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID @@ -1572,7 +1810,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1598,6 +1836,7 @@ async def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + await verify_pipeline(query0) # 1. Query for a NAN. query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1608,6 +1847,7 @@ async def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + await verify_pipeline(query1) # 2. 
Query for not null query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) @@ -1627,7 +1867,7 @@ async def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1657,10 +1897,11 @@ async def test_collection_group_queries(client, cleanup, database): snapshots = [i async for i in query.stream()] found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] - assert found == expected + assert set(found) == set(expected) + await verify_pipeline(query) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_group_queries_startat_endat(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1703,7 +1944,7 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1745,6 +1986,7 @@ async def test_collection_group_queries_filters(client, cleanup, database): snapshots = [i async for i in query.stream()] found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + await verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1766,13 +2008,14 @@ async def test_collection_group_queries_filters(client, cleanup, database): snapshots = [i async for i in query.stream()] found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + await verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_stream_or_get_w_no_explain_options( query_docs, database, method ): @@ -1797,7 +2040,7 @@ async def test_collection_stream_or_get_w_no_explain_options( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1865,7 +2108,7 @@ async def test_collection_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1919,7 +2162,7 @@ async def test_collection_stream_or_get_w_explain_options_analyze_false( @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_partition_query_no_partitions(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1953,7 +2196,7 @@ async def test_partition_query_no_partitions(client, cleanup, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_partition_query(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID n_docs = 128 * 2 + 127 # Minimum partition size is 128 @@ -1980,7 +2223,7 @@ async def test_partition_query(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_get_all(client, cleanup, database): collection_name = "get-all" + UNIQUE_RESOURCE_ID @@ -2053,7 +2296,7 @@ async def test_get_all(client, cleanup, database): assert not snapshots[2].exists -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_live_bulk_writer(client, cleanup, database): from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.bulk_writer import BulkWriter @@ -2077,7 +2320,7 @@ async def test_live_bulk_writer(client, cleanup, database): assert len(await col.get()) == 50 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_batch(client, cleanup, database): collection_name = "batch" + UNIQUE_RESOURCE_ID @@ -2244,7 +2487,7 @@ async def _do_recursive_delete(client, bulk_writer, empty_philosophers=False): ), f"Snapshot at Socrates{path} should have been deleted" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_parallelized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2252,7 +2495,7 @@ async def test_async_recursive_delete_parallelized(client, cleanup, database): await _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_serialized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2260,7 +2503,7 @@ async def test_async_recursive_delete_serialized(client, cleanup, database): await _do_recursive_delete(client, bw) 
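The new pipeline tests added earlier in this file all follow the same flow: derive a pipeline from an existing query with client.pipeline().create_from(query), then consume it via execute() or stream(). A condensed sketch of that flow, using an illustrative collection name and filter value; `client` stands for the async client fixture and `earlier_timestamp` is assumed to hold a datetime captured from a previous read:

# Condensed, illustrative sketch of the pipeline flow exercised by the new async tests.
from google.cloud.firestore_v1.base_query import FieldFilter

query = client.collection("my-collection").where(filter=FieldFilter("a", "==", 1))
pipeline = client.pipeline().create_from(query)

# buffered execution: all results are returned at once
results = await pipeline.execute()
print(len(results))

# streaming execution, optionally pinned to an earlier point in time
async for result in pipeline.stream(read_time=earlier_timestamp):
    print(result.ref.id, result.data())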
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_parallelized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2268,7 +2511,7 @@ async def test_async_recursive_delete_parallelized_empty(client, cleanup, databa await _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_serialized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2276,7 +2519,7 @@ async def test_async_recursive_delete_serialized_empty(client, cleanup, database await _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_recursive_query(client, cleanup, database): col_id: str = f"philosophers-recursive-async-query{UNIQUE_RESOURCE_ID}" await _persist_documents(client, col_id, philosophers_data_set, cleanup) @@ -2315,7 +2558,7 @@ async def test_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_nested_recursive_query(client, cleanup, database): col_id: str = f"philosophers-nested-recursive-async-query{UNIQUE_RESOURCE_ID}" await _persist_documents(client, col_id, philosophers_data_set, cleanup) @@ -2339,7 +2582,7 @@ async def test_nested_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_query(client, cleanup, database): col = client.collection(f"async-chunked-test{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2355,7 +2598,7 @@ async def test_chunked_query(client, cleanup, database): assert lengths[3] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_query_smaller_limit(client, cleanup, database): col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2368,7 +2611,7 @@ async def test_chunked_query_smaller_limit(client, cleanup, database): assert lengths[0] == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_and_recursive(client, cleanup, database): col_id = f"chunked-async-recursive-test{UNIQUE_RESOURCE_ID}" documents = [ @@ -2427,7 +2670,7 @@ async def _chain(*iterators): yield value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_count_async_query_get_default_alias(async_query, database): count_query = async_query.count() result = await count_query.get() @@ -2435,7 +2678,7 @@ async def test_count_async_query_get_default_alias(async_query, database): assert r.alias == 
"field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_with_alias(async_query, database): count_query = async_query.count(alias="total") result = await count_query.get() @@ -2443,7 +2686,7 @@ async def test_async_count_query_get_with_alias(async_query, database): assert r.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_with_limit(async_query, database): count_query = async_query.count(alias="total") result = await count_query.get() @@ -2459,7 +2702,7 @@ async def test_async_count_query_get_with_limit(async_query, database): assert r.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_multiple_aggregations(async_query, database): count_query = async_query.count(alias="total").count(alias="all") @@ -2474,7 +2717,7 @@ async def test_async_count_query_get_multiple_aggregations(async_query, database assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_multiple_aggregations_duplicated_alias( async_query, database ): @@ -2486,7 +2729,7 @@ async def test_async_count_query_get_multiple_aggregations_duplicated_alias( assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_empty_aggregation(async_query, database): from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery @@ -2498,7 +2741,7 @@ async def test_async_count_query_get_empty_aggregation(async_query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_default_alias(async_query, database): count_query = async_query.count() @@ -2507,7 +2750,7 @@ async def test_async_count_query_stream_default_alias(async_query, database): assert aggregation_result.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_with_alias(async_query, database): count_query = async_query.count(alias="total") async for result in count_query.stream(): @@ -2515,7 +2758,7 @@ async def test_async_count_query_stream_with_alias(async_query, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_with_limit(async_query, database): # count without limit count_query = async_query.count(alias="total") @@ -2530,7 +2773,7 @@ async def test_async_count_query_stream_with_limit(async_query, database): assert aggregation_result.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], 
indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_multiple_aggregations(async_query, database): count_query = async_query.count(alias="total").count(alias="all") @@ -2540,7 +2783,7 @@ async def test_async_count_query_stream_multiple_aggregations(async_query, datab assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( async_query, database ): @@ -2553,7 +2796,7 @@ async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_empty_aggregation(async_query, database): from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery @@ -2566,7 +2809,7 @@ async def test_async_count_query_stream_empty_aggregation(async_query, database) assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_default_alias(collection, database): sum_query = collection.sum("stats.product") result = await sum_query.get() @@ -2575,7 +2818,7 @@ async def test_async_sum_query_get_default_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") result = await sum_query.get() @@ -2584,7 +2827,7 @@ async def test_async_sum_query_get_with_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_with_limit(collection, database): sum_query = collection.sum("stats.product", alias="total") result = await sum_query.get() @@ -2600,7 +2843,7 @@ async def test_async_sum_query_get_with_limit(collection, database): assert r.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2617,7 +2860,7 @@ async def test_async_sum_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_default_alias(collection, database): sum_query = collection.sum("stats.product") @@ -2627,7 +2870,7 @@ async def test_async_sum_query_stream_default_alias(collection, database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", 
TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") async for result in sum_query.stream(): @@ -2635,7 +2878,7 @@ async def test_async_sum_query_stream_with_alias(collection, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2650,7 +2893,7 @@ async def test_async_sum_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2662,7 +2905,7 @@ async def test_async_sum_query_stream_multiple_aggregations(collection, database assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_default_alias(collection, database): avg_query = collection.avg("stats.product") result = await avg_query.get() @@ -2672,7 +2915,7 @@ async def test_async_avg_query_get_default_alias(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") result = await avg_query.get() @@ -2681,7 +2924,7 @@ async def test_async_avg_query_get_with_alias(collection, database): assert r.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_with_limit(collection, database): avg_query = collection.avg("stats.product", alias="total") result = await avg_query.get() @@ -2697,7 +2940,7 @@ async def test_async_avg_query_get_with_limit(collection, database): assert r.value == 5 / 12 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2714,7 +2957,7 @@ async def test_async_avg_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_no_explain_options(collection, database): avg_query = collection.avg("stats.product", alias="total") results = await avg_query.get() @@ -2725,7 +2968,7 @@ async def test_async_avg_query_get_w_no_explain_options(collection, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_explain_options_analyze_true(collection, database): avg_query = collection.avg("stats.product", alias="total") results = await avg_query.get(explain_options=ExplainOptions(analyze=True)) @@ -2760,7 +3003,7 @@ async def test_async_avg_query_get_w_explain_options_analyze_true(collection, da @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_explain_options_analyze_false( collection, database ): @@ -2791,7 +3034,7 @@ async def test_async_avg_query_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_default_alias(collection, database): avg_query = collection.avg("stats.product") @@ -2802,7 +3045,7 @@ async def test_async_avg_query_stream_default_alias(collection, database): assert isinstance(aggregation_result.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") async for result in avg_query.stream(): @@ -2810,7 +3053,7 @@ async def test_async_avg_query_stream_with_alias(collection, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2826,7 +3069,7 @@ async def test_async_avg_query_stream_with_limit(collection, database): assert isinstance(aggregation_result.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2838,7 +3081,7 @@ async def test_async_avg_query_stream_multiple_aggregations(collection, database assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_no_explain_options(collection, database): avg_query = collection.avg("stats.product", alias="total") results = avg_query.stream() @@ -2849,7 +3092,7 @@ async def test_async_avg_query_stream_w_no_explain_options(collection, database) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_explain_options_analyze_true( collection, database ): @@ -2894,7 +3137,7 @@ async def test_async_avg_query_stream_w_explain_options_analyze_true( @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_explain_options_analyze_false( collection, database ): @@ -2926,7 +3169,7 @@ async def test_async_avg_query_stream_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) @@ -2989,7 +3232,7 @@ async def create_in_transaction_helper( raise ValueError("Collection can't have more than 2 docs") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_count_query_in_transaction(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document_id_1 = "doc1" + UNIQUE_RESOURCE_ID @@ -3021,7 +3264,7 @@ async def test_count_query_in_transaction(client, cleanup, database): assert r.value == 2 # there are still only 2 docs -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_and_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs and_filter = And( @@ -3037,7 +3280,7 @@ async def test_query_with_and_composite_filter(query_docs, database): assert result.get("stats.product") < 10 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_or_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs or_filter = Or( @@ -3061,7 +3304,7 @@ async def test_query_with_or_composite_filter(query_docs, database): assert lt_10 > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_complex_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs field_filter = FieldFilter("b", "==", 0) @@ -3107,7 +3350,7 @@ async def test_query_with_complex_composite_filter(query_docs, database): assert b_not_3 is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_or_query_in_transaction(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document_id_1 = "doc1" + UNIQUE_RESOURCE_ID @@ -3171,7 +3414,7 @@ async def _make_transaction_query(client, cleanup): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_no_explain_options(client, cleanup, database): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -3204,7 +3447,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_explain_options_analyze_true( client, cleanup, database ): @@ -3246,7 +3489,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_explain_options_analyze_false( client, cleanup, database ): @@ -3283,7 +3526,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_no_explain_options(client, cleanup, database): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -3316,7 +3559,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_explain_options_analyze_true( client, cleanup, database ): @@ -3350,7 +3593,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_explain_options_analyze_false( client, cleanup, database ): @@ -3383,7 +3626,7 @@ async def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_with_read_time(client, cleanup, database): """ Test query profiling in transactions. 
diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 69ca69ec7..96928e88e 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py
@@ -20,6 +20,7 @@ from google.cloud.firestore_v1.base_aggregation import (
     AggregationResult,
     AvgAggregation,
+    BaseAggregation,
     CountAggregation,
     SumAggregation,
 )
@@ -27,6 +28,7 @@ from google.cloud.firestore_v1.query_results import QueryResultsList
 from google.cloud.firestore_v1.stream_generator import StreamGenerator
 from google.cloud.firestore_v1.types import RunAggregationQueryResponse
+from google.cloud.firestore_v1.field_path import FieldPath
 from google.protobuf.timestamp_pb2 import Timestamp
 from tests.unit.v1._test_helpers import (
     make_aggregation_query,
@@ -121,6 +123,63 @@ def test_avg_aggregation_no_alias_to_pb():
     assert got_pb.alias == ""
+
+@pytest.mark.parametrize(
+    "in_alias,expected_alias", [("total", "total"), (None, "field_1")]
+)
+def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    count_aggregation = CountAggregation(alias=in_alias)
+    got = count_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, AliasedExpression)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Count)
+    assert len(got.expr.params) == 0
+
+
+@pytest.mark.parametrize(
+    "in_alias,expected_path,expected_alias",
+    [("total", "path", "total"), (None, "some_ref", "field_1")],
+)
+def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression
+
+    sum_aggregation = SumAggregation(expected_path, alias=in_alias)
+    got = sum_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, AliasedExpression)
+    assert got.alias == expected_alias
+    assert got.expr.name == "sum"
+    assert got.expr.params[0].path == expected_path
+
+
+@pytest.mark.parametrize(
+    "in_alias,expected_path,expected_alias",
+    [("total", "path", "total"), (None, "some_ref", "field_1")],
+)
+def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression
+
+    avg_aggregation = AvgAggregation(expected_path, alias=in_alias)
+    got = avg_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, AliasedExpression)
+    assert got.alias == expected_alias
+    assert got.expr.name == "average"
+    assert got.expr.params[0].path == expected_path
+
+
+def test_aggregation__pipeline_alias_increment():
+    """
+    BaseAggregation._pipeline_alias should pull from an autoindexer to populate field numbers
+    """
+    autoindex = iter(range(10))
+    mock_instance = mock.Mock()
+    mock_instance.alias = None
+    for i in range(10):
+        got_name = BaseAggregation._pipeline_alias(mock_instance, autoindex)
+        assert got_name == f"field_{i}"
+
+
 def test_aggregation_query_constructor():
     client = make_client()
     parent = client.collection("dee")
@@ -894,6 +953,16 @@ def test_aggregation_query_stream_w_explain_options_analyze_false():
     _aggregation_query_stream_helper(explain_options=ExplainOptions(analyze=False))
+
+def test_aggregation__to_dict():
+    expected_alias = "alias"
+    expected_value = "value"
+    instance = AggregationResult(alias=expected_alias, value=expected_value)
+    dict_result = instance._to_dict()
+    assert len(dict_result) == 1
+    assert next(iter(dict_result)) ==
expected_alias + assert dict_result[expected_alias] == expected_value + + def test_aggregation_from_query(): from google.cloud.firestore_v1 import _helpers @@ -952,3 +1021,144 @@ def test_aggregation_from_query(): metadata=client._rpc_metadata, **kwargs, ) + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.sum(field, alias=in_alias) + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert aggregate_stage.accumulators[0].expr.name == "sum" + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.avg(field, alias=in_alias) + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert aggregate_stage.accumulators[0].expr.name == "average" + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "in_alias,out_alias", + [ + (None, "field_1"), + ("overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_count(in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Count + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.count(alias=in_alias) + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert 
isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Count) + assert aggregate_stage.accumulators[0].alias == out_alias + + +def test_aggreation_to_pipeline_count_increment(): + """ + When aliases aren't given, should assign incrementing field_n values + """ + from google.cloud.firestore_v1.pipeline_expressions import Count + + n = 100 + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + for _ in range(n): + aggregation_query.count() + pipeline = aggregation_query._build_pipeline(client.pipeline()) + aggregate_stage = pipeline.stages[1] + assert len(aggregate_stage.accumulators) == n + for i in range(n): + assert isinstance(aggregate_stage.accumulators[i].expr, Count) + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" + + +def test_aggreation_to_pipeline_complex(): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate, Select + + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + aggregation_query = make_aggregation_query(query) + aggregation_query.sum("field", alias="alias") + aggregation_query.count() + aggregation_query.avg("other") + aggregation_query.sum("final") + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 3 + assert isinstance(pipeline.stages[0], Collection) + assert isinstance(pipeline.stages[1], Select) + assert isinstance(pipeline.stages[2], Aggregate) + aggregate_stage = pipeline.stages[2] + assert len(aggregate_stage.accumulators) == 4 + assert aggregate_stage.accumulators[0].expr.name == "sum" + assert aggregate_stage.accumulators[0].alias == "alias" + assert aggregate_stage.accumulators[1].expr.name == "count" + assert aggregate_stage.accumulators[1].alias == "field_1" + assert aggregate_stage.accumulators[2].expr.name == "average" + assert aggregate_stage.accumulators[2].alias == "field_2" + assert aggregate_stage.accumulators[3].expr.name == "sum" + assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 9140f53e8..025146145 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -31,6 +31,7 @@ from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.query_profile import ExplainMetrics, QueryExplainError from google.cloud.firestore_v1.query_results import QueryResultsList +from google.cloud.firestore_v1.field_path import FieldPath _PROJECT = "PROJECT" @@ -696,3 +697,144 @@ async def test_aggregation_query_stream_w_explain_options_analyze_false(): explain_options = ExplainOptions(analyze=False) await _async_aggregation_query_stream_helper(explain_options=explain_options) + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate + + client = make_async_client() + parent = client.collection("dee") + query = 
make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.sum(field, alias=in_alias) + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert aggregate_stage.accumulators[0].expr.name == "sum" + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.avg(field, alias=in_alias) + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert aggregate_stage.accumulators[0].expr.name == "average" + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "in_alias,out_alias", + [ + (None, "field_1"), + ("overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_count(in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Count + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.count(alias=in_alias) + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Count) + assert aggregate_stage.accumulators[0].alias == out_alias + + +def test_aggreation_to_pipeline_count_increment(): + """ + When aliases aren't given, should assign incrementing field_n values + """ + from google.cloud.firestore_v1.pipeline_expressions import Count + + n = 100 + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + for _ in range(n): + 
aggregation_query.count() + pipeline = aggregation_query._build_pipeline(client.pipeline()) + aggregate_stage = pipeline.stages[1] + assert len(aggregate_stage.accumulators) == n + for i in range(n): + assert isinstance(aggregate_stage.accumulators[i].expr, Count) + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" + + +def test_async_aggreation_to_pipeline_complex(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate, Select + + client = make_async_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.sum("field", alias="alias") + aggregation_query.count() + aggregation_query.avg("other") + aggregation_query.sum("final") + pipeline = aggregation_query._build_pipeline(client.pipeline()) + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 3 + assert isinstance(pipeline.stages[0], Collection) + assert isinstance(pipeline.stages[1], Select) + assert isinstance(pipeline.stages[2], Aggregate) + aggregate_stage = pipeline.stages[2] + assert len(aggregate_stage.accumulators) == 4 + assert aggregate_stage.accumulators[0].expr.name == "sum" + assert aggregate_stage.accumulators[0].alias == "alias" + assert aggregate_stage.accumulators[1].expr.name == "count" + assert aggregate_stage.accumulators[1].alias == "field_1" + assert aggregate_stage.accumulators[2].expr.name == "average" + assert aggregate_stage.accumulators[2].alias == "field_2" + assert aggregate_stage.accumulators[3].expr.name == "sum" + assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 9b49e5bf0..3aeef8f9f 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -563,6 +563,17 @@ def test_asyncclient_transaction(): assert transaction._id is None +def test_asyncclient_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + + client = _make_default_async_client() + ppl = client.pipeline() + assert client._pipeline_cls == AsyncPipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client + + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index a0194ace5..34a259996 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -601,3 +601,17 @@ def test_asynccollectionreference_recursive(): col = _make_async_collection_reference("collection") assert isinstance(col.recursive(), AsyncQuery) + + +def test_asynccollectionreference_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_stages import Collection + + client = make_async_client() + collection = _make_async_collection_reference("collection", client=client) + pipeline = collection._build_pipeline(client.pipeline()) + assert isinstance(pipeline, AsyncPipeline) + # should have single "Collection" stage + assert len(pipeline.stages) == 1 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/collection" diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py new file mode 100644 index 000000000..18805b7b2 --- /dev/null +++ 
b/tests/unit/v1/test_async_pipeline.py @@ -0,0 +1,441 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Field + +from tests.unit.v1._test_helpers import make_async_client + + +def _make_async_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + return AsyncPipeline._create_with_stages(client, *args) + + +async def _async_it(list): + for value in list: + yield value + + +def test_ctor(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = object() + instance = AsyncPipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = object() + stages = [object() for i in range(10)] + instance = AsyncPipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_async_pipeline_repr_empty(): + ppl = _make_async_pipeline() + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline()" + + +def test_async_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_async_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline(SingleStage)" + + +def test_async_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.RawStage("second", 2) + stage_3 = stages.RawStage("third", 3) + ppl = _make_async_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "AsyncPipeline(\n" + " Collection(path='/path'),\n" + " RawStage(name='second'),\n" + " RawStage(name='third')\n" + ")" + ) + + +def test_async_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.RawStage("custom", i) for i in range(num_stages)] + ppl = _make_async_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("RawStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_async_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.RawStage("first") + stage_2 = stages.RawStage("second") + ppl = _make_async_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_async_pipeline_append(): + """append should create a new pipeline with the additional stage""" + stage_1 = stages.RawStage("first") + ppl_1 = _make_async_pipeline(stage_1, client=object()) + stage_2 = stages.RawStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 
+ assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(stages.RawStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it( + [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})] + ) + ppl_1 = _make_async_pipeline(stages.RawStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = make_async_client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + 
assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, AsyncDocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_multiple(): + """ + test stream pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = make_async_client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=0)}), + Document(fields={"key": Value(integer_value=1)}), + ], + execution_time={"seconds": 0}, + ), + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=2)}), + Document(fields={"key": Value(integer_value=3)}), + ], + execution_time={"seconds": 1}, + ), + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 4 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + for idx, response in enumerate(results): + assert isinstance(response, PipelineResult) + assert response.data() == {"key": idx} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_with_transaction(): + """ + test stream pipeline with transaction context + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + transaction = AsyncTransaction(client) + transaction._id = b"123" + + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(client=client) + + [r async for r in ppl_1.stream(transaction=transaction)] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"123" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_with_read_time(): + """ + test stream pipeline with read_time + """ + import datetime + + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = 
mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+
+    read_time = datetime.datetime.now(tz=datetime.timezone.utc)
+
+    mock_rpc.return_value = _async_it([ExecutePipelineResponse()])
+    ppl_1 = _make_async_pipeline(client=client)
+
+    [r async for r in ppl_1.stream(read_time=read_time)]
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.read_time == read_time
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide same results from pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+
+    real_client = make_async_client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+    mock_response = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    mock_rpc.return_value = _async_it(mock_response)
+    ppl_1 = _make_async_pipeline(client=client)
+
+    stream_results = [r async for r in ppl_1.stream()]
+    # reset response
+    mock_rpc.return_value = _async_it(mock_response)
+    execute_results = await ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+@pytest.mark.parametrize(
+    "method,args,result_cls",
+    [
+        ("add_fields", (Field.of("n"),), stages.AddFields),
+        ("remove_fields", ("name",), stages.RemoveFields),
+        ("remove_fields", (Field.of("n"),), stages.RemoveFields),
+        ("select", ("name",), stages.Select),
+        ("select", (Field.of("n"),), stages.Select),
+        ("where", (Field.of("n").exists(),), stages.Where),
+        ("find_nearest", ("name", [0.1], "cosine"), stages.FindNearest),
+        (
+            "find_nearest",
+            ("name", [0.1], "cosine", stages.FindNearestOptions(10)),
+            stages.FindNearest,
+        ),
+        ("sort", (Field.of("n").descending(),), stages.Sort),
+        ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort),
+        ("sample", (10,), stages.Sample),
+        ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample),
+        ("union", (_make_async_pipeline(),), stages.Union),
+        ("unnest", ("field_name",), stages.Unnest),
+        ("unnest", ("field_name", "alias"), stages.Unnest),
+        ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest),
+        ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest),
+        ("raw_stage", ("stage_name",), stages.RawStage),
+        ("raw_stage", ("stage_name", Field.of("n")), stages.RawStage),
+        ("offset", (1,), stages.Offset),
+        ("limit", (1,), stages.Limit),
+        ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate),
+        ("distinct", ("field_name",), stages.Distinct),
+        ("distinct", (Field.of("n"), "second"), stages.Distinct),
+    ],
+)
+def test_async_pipeline_methods(method, args, result_cls):
+    start_ppl = _make_async_pipeline()
+    method_ptr = getattr(start_ppl, method)
+    result_ppl = method_ptr(*args)
+    assert result_ppl != start_ppl
+    assert len(start_ppl.stages) == 0
+    assert len(result_ppl.stages) == 1
+    assert
isinstance(result_ppl.stages[0], result_cls) + + +def test_async_pipeline_aggregate_with_groups(): + start_ppl = _make_async_pipeline() + result_ppl = start_ppl.aggregate(Field.of("title"), groups=[Field.of("author")]) + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], stages.Aggregate) + assert list(result_ppl.stages[0].groups) == [Field.of("author")] + assert list(result_ppl.stages[0].accumulators) == [Field.of("title")] diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 54c80e5ad..6e2aa8393 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -909,3 +909,22 @@ async def test_asynccollectiongroup_get_partitions_w_offset(): query = _make_async_collection_group(parent).offset(10) with pytest.raises(ValueError): [i async for i in query.get_partitions(2)] + + +def test_asyncquery_collection_pipeline_type(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = make_async_client() + parent = client.collection("test") + query = parent._query() + ppl = query._build_pipeline(client.pipeline()) + assert isinstance(ppl, AsyncPipeline) + + +def test_asyncquery_collectiongroup_pipeline_type(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = make_async_client() + query = client.collection_group("test") + ppl = query._build_pipeline(client.pipeline()) + assert isinstance(ppl, AsyncPipeline) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index 22baa0c5f..9124e4d01 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -422,6 +422,21 @@ def test_basecollectionreference_end_at(mock_query): assert query == mock_query.end_at.return_value +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_pipeline(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = _make_base_collection_reference("collection") + mock_source = mock.Mock() + pipeline = collection._build_pipeline(mock_source) + + mock_query._build_pipeline.assert_called_once_with(mock_source) + assert pipeline == mock_query._build_pipeline.return_value + + @mock.patch("random.choice") def test__auto_id(mock_rand_choice): from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS, _auto_id diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 7804b0430..4a4dac727 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -18,6 +18,7 @@ import pytest from tests.unit.v1._test_helpers import make_client +from google.cloud.firestore_v1 import pipeline_stages as stages def _make_base_query(*args, **kwargs): @@ -1993,6 +1994,165 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.update_time == response_pb._pb.document.update_time +def test__query_pipeline_decendants(): + client = make_client() + query = client.collection_group("my_col") + pipeline = query._build_pipeline(client.pipeline()) + + assert len(pipeline.stages) == 1 + stage = pipeline.stages[0] + assert isinstance(stage, stages.CollectionGroup) + assert stage.collection_id == "my_col" + + +@pytest.mark.parametrize( + "in_path,out_path", + [ + ("my_col/doc/", "/my_col/doc/"), + ("/my_col/doc", 
"/my_col/doc"), + ("my_col/doc/sub_col", "/my_col/doc/sub_col"), + ], +) +def test__query_pipeline_no_decendants(in_path, out_path): + client = make_client() + collection = client.collection(in_path) + query = collection._query() + pipeline = query._build_pipeline(client.pipeline()) + + assert len(pipeline.stages) == 1 + stage = pipeline.stages[0] + assert isinstance(stage, stages.Collection) + assert stage.path == out_path + + +def test__query_pipeline_composite_filter(): + from google.cloud.firestore_v1 import FieldFilter + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + in_filter = FieldFilter("field_a", "==", "value_a") + query = client.collection("my_col").where(filter=in_filter) + with mock.patch.object( + expr.BooleanExpression, "_from_query_filter_pb" + ) as convert_mock: + pipeline = query._build_pipeline(client.pipeline()) + convert_mock.assert_called_once_with(in_filter._to_pb(), client) + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Where) + assert stage.condition == convert_mock.return_value + + +def test__query_pipeline_projections(): + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + pipeline = query._build_pipeline(client.pipeline()) + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Select) + assert len(stage.projections) == 2 + assert stage.projections[0].path == "field_a" + assert stage.projections[1].path == "field_b.c" + + +def test__query_pipeline_order_exists_multiple(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query = client.collection("my_col").order_by("field_a").order_by("field_b") + pipeline = query._build_pipeline(client.pipeline()) + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline.stages) == 3 + where_stage = pipeline.stages[1] + assert isinstance(where_stage, stages.Where) + # should have and with both orderings + assert isinstance(where_stage.condition, expr.And) + assert len(where_stage.condition.params) == 2 + operands = [p for p in where_stage.condition.params] + assert operands[0].name == "exists" + assert operands[0].params[0].path == "field_a" + assert operands[1].name == "exists" + assert operands[1].params[0].path == "field_b" + + +def test__query_pipeline_order_exists_single(): + client = make_client() + query_single = client.collection("my_col").order_by("field_c") + pipeline_single = query_single._build_pipeline(client.pipeline()) + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline_single.stages) == 3 + where_stage_single = pipeline_single.stages[1] + assert isinstance(where_stage_single, stages.Where) + assert where_stage_single.condition.name == "exists" + assert where_stage_single.condition.params[0].path == "field_c" + + +def test__query_pipeline_order_sorts(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1.base_query import BaseQuery + + client = make_client() + query = ( + client.collection("my_col") + .order_by("field_a", direction=BaseQuery.ASCENDING) + .order_by("field_b", direction=BaseQuery.DESCENDING) + ) + pipeline = query._build_pipeline(client.pipeline()) + + assert len(pipeline.stages) == 3 + sort_stage = pipeline.stages[2] + assert isinstance(sort_stage, stages.Sort) + assert len(sort_stage.orders) == 2 + assert 
isinstance(sort_stage.orders[0], expr.Ordering) + assert sort_stage.orders[0].expr.path == "field_a" + assert sort_stage.orders[0].order_dir == expr.Ordering.Direction.ASCENDING + assert isinstance(sort_stage.orders[1], expr.Ordering) + assert sort_stage.orders[1].expr.path == "field_b" + assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING + + +def test__query_pipeline_unsupported(): + client = make_client() + query_start = client.collection("my_col").start_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_start._build_pipeline(client.pipeline()) + + query_end = client.collection("my_col").end_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_end._build_pipeline(client.pipeline()) + + query_limit_last = client.collection("my_col").limit_to_last(10) + with pytest.raises(NotImplementedError, match="limit_to_last"): + query_limit_last._build_pipeline(client.pipeline()) + + +def test__query_pipeline_limit(): + client = make_client() + query = client.collection("my_col").limit(15) + pipeline = query._build_pipeline(client.pipeline()) + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Limit) + assert stage.limit == 15 + + +def test__query_pipeline_offset(): + client = make_client() + query = client.collection("my_col").offset(5) + pipeline = query._build_pipeline(client.pipeline()) + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Offset) + assert stage.offset == 5 + + def _make_order_pb(field_path, direction): from google.cloud.firestore_v1.types import query diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index df3ae15b4..9d0199f92 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -648,6 +648,18 @@ def test_client_transaction(database): assert transaction._id is None +@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) +def test_client_pipeline(database): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + + client = _make_default_client(database=database) + ppl = client.pipeline() + assert client._pipeline_cls == Pipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client + + def _make_batch_response(**kwargs): from google.cloud.firestore_v1.types import firestore diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index da91651b9..156b314aa 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -510,6 +510,21 @@ def test_stream_w_read_time(query_class): ) +def test_collectionreference_pipeline(): + from tests.unit.v1 import _test_helpers + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_stages import Collection + + client = _test_helpers.make_client() + collection = _make_collection_reference("collection", client=client) + pipeline = collection._build_pipeline(client.pipeline()) + assert isinstance(pipeline, Pipeline) + # should have single "Collection" stage + assert len(pipeline.stages) == 1 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/collection" + + @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) def test_on_snapshot(watch): collection = _make_collection_reference("collection") diff --git a/tests/unit/v1/test_pipeline.py 
b/tests/unit/v1/test_pipeline.py new file mode 100644 index 000000000..10509cafb --- /dev/null +++ b/tests/unit/v1/test_pipeline.py @@ -0,0 +1,430 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Field + +from tests.unit.v1._test_helpers import make_client + + +def _make_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.pipeline import Pipeline + + return Pipeline._create_with_stages(client, *args) + + +def test_ctor(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + instance = Pipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + stages = [object() for i in range(10)] + instance = Pipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_pipeline_repr_empty(): + ppl = _make_pipeline() + repr_str = repr(ppl) + assert repr_str == "Pipeline()" + + +def test_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "Pipeline(SingleStage)" + + +def test_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.RawStage("second", 2) + stage_3 = stages.RawStage("third", 3) + ppl = _make_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "Pipeline(\n" + " Collection(path='/path'),\n" + " RawStage(name='second'),\n" + " RawStage(name='third')\n" + ")" + ) + + +def test_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.RawStage("custom", i) for i in range(num_stages)] + ppl = _make_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("RawStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.RawStage("first") + stage_2 = stages.RawStage("second") + ppl = _make_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_pipeline__to_pb_with_options(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + from google.cloud.firestore_v1.types.document import Value + + ppl = _make_pipeline() + options = {"option_1": Value(integer_value=1)} + pb = ppl._to_pb(**options) + assert isinstance(pb, StructuredPipeline) + assert pb.options["option_1"].integer_value == 1 + + +def test_pipeline_append(): + """append should create a new pipeline with the 
additional stage""" + + stage_1 = stages.RawStage("first") + ppl_1 = _make_pipeline(stage_1, client=object()) + stage_2 = stages.RawStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +def test_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(stages.RawStage("s"), client=client) + + results = list(ppl_1.stream()) + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +def test_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ + ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9}) + ] + ppl_1 = _make_pipeline(stages.RawStage("s"), client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +def test_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = make_client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + 
assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.transaction == b""
+
+    response = results[0]
+    assert isinstance(response, PipelineResult)
+    assert isinstance(response.ref, DocumentReference)
+    assert response.ref.path == "test/my_doc"
+    assert response.id == "my_doc"
+    assert response.create_time.seconds == 1
+    assert response.update_time.seconds == 2
+    assert response.execution_time.seconds == 9
+    assert response.data() == {"key": "str_val"}
+
+
+def test_pipeline_stream_multiple():
+    """
+    test stream pipeline with multiple docs and responses
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+
+    real_client = make_client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    mock_rpc.return_value = [
+        ExecutePipelineResponse(
+            results=[
+                Document(fields={"key": Value(integer_value=0)}),
+                Document(fields={"key": Value(integer_value=1)}),
+            ],
+            execution_time={"seconds": 0},
+        ),
+        ExecutePipelineResponse(
+            results=[
+                Document(fields={"key": Value(integer_value=2)}),
+                Document(fields={"key": Value(integer_value=3)}),
+            ],
+            execution_time={"seconds": 1},
+        ),
+    ]
+    ppl_1 = _make_pipeline(client=client)
+
+    results = list(ppl_1.stream())
+    assert len(results) == 4
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+
+    for idx, response in enumerate(results):
+        assert isinstance(response, PipelineResult)
+        assert response.data() == {"key": idx}
+
+
+def test_pipeline_stream_with_transaction():
+    """
+    test stream pipeline with transaction context
+    """
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.transaction import Transaction
+
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    transaction = Transaction(client)
+    transaction._id = b"123"
+
+    mock_rpc.return_value = [ExecutePipelineResponse()]
+    ppl_1 = _make_pipeline(client=client)
+
+    list(ppl_1.stream(transaction=transaction))
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.transaction == b"123"
+
+
+def test_pipeline_stream_with_read_time():
+    """
+    test stream pipeline with read_time
+    """
+    import datetime
+
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    read_time = datetime.datetime.now(tz=datetime.timezone.utc)
+
+    mock_rpc.return_value = [ExecutePipelineResponse()]
+    ppl_1 = _make_pipeline(client=client)
+
+    
list(ppl_1.stream(read_time=read_time)) + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.read_time == read_time + + +def test_pipeline_execute_stream_equivalence(): + """ + Pipeline.execute should provide same results from pipeline.stream, as a list + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import Value + + real_client = make_client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + fields={"key": Value(string_value="str_val")}, + ) + ], + ) + ] + ppl_1 = _make_pipeline(client=client) + + stream_results = list(ppl_1.stream()) + execute_results = ppl_1.execute() + assert stream_results == execute_results + assert stream_results[0].data()["key"] == "str_val" + assert execute_results[0].data()["key"] == "str_val" + + +@pytest.mark.parametrize( + "method,args,result_cls", + [ + ("add_fields", (Field.of("n"),), stages.AddFields), + ("remove_fields", ("name",), stages.RemoveFields), + ("remove_fields", (Field.of("n"),), stages.RemoveFields), + ("select", ("name",), stages.Select), + ("select", (Field.of("n"),), stages.Select), + ("where", (Field.of("n").exists(),), stages.Where), + ("find_nearest", ("name", [0.1], "cosine"), stages.FindNearest), + ( + "find_nearest", + ("name", [0.1], "cosine", stages.FindNearestOptions(10)), + stages.FindNearest, + ), + ("replace_with", ("name",), stages.ReplaceWith), + ("replace_with", (Field.of("n"),), stages.ReplaceWith), + ("sort", (Field.of("n").descending(),), stages.Sort), + ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("sample", (10,), stages.Sample), + ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), + ("union", (_make_pipeline(),), stages.Union), + ("unnest", ("field_name",), stages.Unnest), + ("unnest", ("field_name", "alias"), stages.Unnest), + ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), + ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), + ("raw_stage", ("stage_name",), stages.RawStage), + ("raw_stage", ("stage_name", Field.of("n")), stages.RawStage), + ("offset", (1,), stages.Offset), + ("limit", (1,), stages.Limit), + ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), + ("distinct", ("field_name",), stages.Distinct), + ("distinct", (Field.of("n"), "second"), stages.Distinct), + ], +) +def test_pipeline_methods(method, args, result_cls): + start_ppl = _make_pipeline() + method_ptr = getattr(start_ppl, method) + result_ppl = method_ptr(*args) + assert result_ppl != start_ppl + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], result_cls) + + +def test_pipeline_aggregate_with_groups(): + start_ppl = _make_pipeline() + result_ppl = start_ppl.aggregate(Field.of("title"), groups=[Field.of("author")]) + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], stages.Aggregate) + assert list(result_ppl.stages[0].groups) == [Field.of("author")] + assert list(result_ppl.stages[0].accumulators) == 
[Field.of("title")] diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py new file mode 100644 index 000000000..e2c6dcd0f --- /dev/null +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -0,0 +1,1567 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import mock +import math +import datetime + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.types import document as document_pb +from google.cloud.firestore_v1.types import query as query_pb +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint +import google.cloud.firestore_v1.pipeline_expressions as expr +from google.cloud.firestore_v1.pipeline_expressions import BooleanExpression +from google.cloud.firestore_v1.pipeline_expressions import Expression +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Ordering + + +@pytest.fixture +def mock_client(): + client = mock.Mock(spec=["_database_string", "collection"]) + client._database_string = "projects/p/databases/d" + return client + + +class TestOrdering: + @pytest.mark.parametrize( + "direction_arg,expected_direction", + [ + ("ASCENDING", Ordering.Direction.ASCENDING), + ("DESCENDING", Ordering.Direction.DESCENDING), + ("ascending", Ordering.Direction.ASCENDING), + ("descending", Ordering.Direction.DESCENDING), + (Ordering.Direction.ASCENDING, Ordering.Direction.ASCENDING), + (Ordering.Direction.DESCENDING, Ordering.Direction.DESCENDING), + ], + ) + def test_ctor(self, direction_arg, expected_direction): + instance = Ordering("field1", direction_arg) + assert isinstance(instance.expr, Field) + assert instance.expr.path == "field1" + assert instance.order_dir == expected_direction + + def test_repr(self): + field_expr = Field.of("field1") + instance = Ordering(field_expr, "ASCENDING") + repr_str = repr(instance) + assert repr_str == "Field.of('field1').ascending()" + + instance = Ordering(field_expr, "DESCENDING") + repr_str = repr(instance) + assert repr_str == "Field.of('field1').descending()" + + def test_to_pb(self): + field_expr = Field.of("field1") + instance = Ordering(field_expr, "ASCENDING") + result = instance._to_pb() + assert result.map_value.fields["expression"].field_reference_value == "field1" + assert result.map_value.fields["direction"].string_value == "ascending" + + instance = Ordering(field_expr, "DESCENDING") + result = instance._to_pb() + assert result.map_value.fields["expression"].field_reference_value == "field1" + assert result.map_value.fields["direction"].string_value == "descending" + + +class TestConstant: + @pytest.mark.parametrize( + "input_val, to_pb_val", + [ + ("test", Value(string_value="test")), + ("", Value(string_value="")), + (10, Value(integer_value=10)), + 
(0, Value(integer_value=0)), + (10.0, Value(double_value=10)), + (0.0, Value(double_value=0)), + (True, Value(boolean_value=True)), + (b"test", Value(bytes_value=b"test")), + (None, Value(null_value=0)), + ( + datetime.datetime(2025, 5, 12), + Value(timestamp_value={"seconds": 1747008000}), + ), + (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), + ( + Vector([1.0, 2.0]), + Value( + map_value={ + "fields": { + "__type__": Value(string_value="__vector__"), + "value": Value( + array_value={ + "values": [Value(double_value=v) for v in [1, 2]], + } + ), + } + } + ), + ), + ], + ) + def test_to_pb(self, input_val, to_pb_val): + instance = Constant.of(input_val) + assert instance._to_pb() == to_pb_val + + @pytest.mark.parametrize("input", [float("nan"), math.nan]) + def test_nan_to_pb(self, input): + instance = Constant.of(input) + assert repr(instance) == "Constant.of(math.nan)" + pb_val = instance._to_pb() + assert math.isnan(pb_val.double_value) + + @pytest.mark.parametrize( + "input_val,expected", + [ + ("test", "Constant.of('test')"), + ("", "Constant.of('')"), + (10, "Constant.of(10)"), + (0, "Constant.of(0)"), + (10.0, "Constant.of(10.0)"), + (0.0, "Constant.of(0.0)"), + (True, "Constant.of(True)"), + (b"test", "Constant.of(b'test')"), + (None, "Constant.of(None)"), + ( + datetime.datetime(2025, 5, 12), + "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))", + ), + (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), + ([1, 2, 3], "Constant.of([1, 2, 3])"), + ({"a": "b"}, "Constant.of({'a': 'b'})"), + (Vector([1.0, 2.0]), "Constant.of(Vector<1.0, 2.0>)"), + ], + ) + def test_repr(self, input_val, expected): + instance = Constant.of(input_val) + repr_string = repr(instance) + assert repr_string == expected + + @pytest.mark.parametrize( + "first,second,expected", + [ + (Constant.of(1), Constant.of(2), False), + (Constant.of(1), Constant.of(1), True), + (Constant.of(1), 1, True), + (Constant.of(1), 2, False), + (Constant.of("1"), 1, False), + (Constant.of("1"), "1", True), + (Constant.of(None), Constant.of(0), False), + (Constant.of(None), Constant.of(None), True), + (Constant.of([1, 2, 3]), Constant.of([1, 2, 3]), True), + (Constant.of([1, 2, 3]), Constant.of([1, 2]), False), + (Constant.of([1, 2, 3]), [1, 2, 3], True), + (Constant.of([1, 2, 3]), object(), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + +class TestSelectable: + """ + contains tests for each Expression class that derives from Selectable + """ + + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expr.Selectable() + + def test_value_from_selectables(self): + selectable_list = [ + Field.of("field1"), + Field.of("field2").as_("alias2"), + ] + result = expr.Selectable._value_from_selectables(*selectable_list) + assert len(result.map_value.fields) == 2 + assert result.map_value.fields["field1"].field_reference_value == "field1" + assert result.map_value.fields["alias2"].field_reference_value == "field2" + + @pytest.mark.parametrize( + "first,second,expected", + [ + (Field.of("field1"), Field.of("field1"), True), + (Field.of("field1"), Field.of("field2"), False), + (Field.of(None), object(), False), + (Field.of("f").as_("a"), Field.of("f").as_("a"), True), + (Field.of("one").as_("a"), Field.of("two").as_("a"), False), + (Field.of("f").as_("one"), Field.of("f").as_("two"), False), + (Field.of("field"), Field.of("field").as_("alias"), False), + (Field.of("field").as_("alias"), 
Field.of("field"), False),
+        ],
+    )
+    def test_equality(self, first, second, expected):
+        assert (first == second) is expected
+
+    class TestField:
+        def test_repr(self):
+            instance = Field.of("field1")
+            repr_string = repr(instance)
+            assert repr_string == "Field.of('field1')"
+
+        def test_of(self):
+            instance = Field.of("field1")
+            assert instance.path == "field1"
+
+        def test_to_pb(self):
+            instance = Field.of("field1")
+            result = instance._to_pb()
+            assert result.field_reference_value == "field1"
+
+        def test_to_map(self):
+            instance = Field.of("field1")
+            result = instance._to_map()
+            assert result[0] == "field1"
+            assert result[1] == Value(field_reference_value="field1")
+
+    class TestAliasedExpression:
+        def test_repr(self):
+            instance = Field.of("field1").as_("alias1")
+            assert repr(instance) == "Field.of('field1').as_('alias1')"
+
+        def test_ctor(self):
+            arg = Field.of("field1")
+            alias = "alias1"
+            instance = expr.AliasedExpression(arg, alias)
+            assert instance.expr == arg
+            assert instance.alias == alias
+
+        def test_to_pb(self):
+            arg = Field.of("field1")
+            alias = "alias1"
+            instance = expr.AliasedExpression(arg, alias)
+            result = instance._to_pb()
+            assert result.map_value.fields.get("alias1") == arg._to_pb()
+
+        def test_to_map(self):
+            instance = Field.of("field1").as_("alias1")
+            result = instance._to_map()
+            assert result[0] == "alias1"
+            assert result[1] == Value(field_reference_value="field1")
+
+
+class TestBooleanExpression:
+    def test__from_query_filter_pb_composite_filter_or(self, mock_client):
+        """
+        test composite OR filters
+
+        should create an OR statement, made up of ANDs checking for the existence of the relevant fields
+        """
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=_helpers.encode_value("val1"),
+        )
+        filter2_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field2"),
+            op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL,
+        )
+
+        composite_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.OR,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter1_pb),
+                query_pb.StructuredQuery.Filter(unary_filter=filter2_pb),
+            ],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(
+            composite_filter=composite_pb
+        )
+
+        result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        # should include existence checks
+        field1 = Field.of("field1")
+        field2 = Field.of("field2")
+        expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1")))
+        expected_cond2 = expr.And(field2.exists(), field2.equal(None))
+        expected = expr.Or(expected_cond1, expected_cond2)
+
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_composite_filter_and(self, mock_client):
+        """
+        test composite AND filters
+
+        should create an AND statement, made up of ANDs checking for the existence of the relevant fields
+        """
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN,
+            value=_helpers.encode_value(100),
+        )
+        filter2_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field2"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN,
+            value=_helpers.encode_value(200),
+        )
+
+        composite_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.AND,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter1_pb),
+                query_pb.StructuredQuery.Filter(field_filter=filter2_pb),
+            ],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(
+            composite_filter=composite_pb
+        )
+
+        result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        # should include existence checks
+        field1 = Field.of("field1")
+        field2 = Field.of("field2")
+        expected_cond1 = expr.And(field1.exists(), field1.greater_than(Constant(100)))
+        expected_cond2 = expr.And(field2.exists(), field2.less_than(Constant(200)))
+        expected = expr.And(expected_cond1, expected_cond2)
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_composite_filter_nested(self, mock_client):
+        """
+        test composite filter with complex nested checks
+        """
+        # OR (field1 == "val1", AND(field2 > 10, field3 IS NOT NULL))
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=_helpers.encode_value("val1"),
+        )
+        filter2_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field2"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN,
+            value=_helpers.encode_value(10),
+        )
+        filter3_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field3"),
+            op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL,
+        )
+        inner_and_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.AND,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter2_pb),
+                query_pb.StructuredQuery.Filter(unary_filter=filter3_pb),
+            ],
+        )
+        outer_or_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.OR,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter1_pb),
+                query_pb.StructuredQuery.Filter(composite_filter=inner_and_pb),
+            ],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(
+            composite_filter=outer_or_pb
+        )
+
+        result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        field1 = Field.of("field1")
+        field2 = Field.of("field2")
+        field3 = Field.of("field3")
+        expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1")))
+        expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10)))
+        expected_cond3 = expr.And(field3.exists(), expr.Not(field3.equal(None)))
+        expected_inner_and = expr.And(expected_cond2, expected_cond3)
+        expected_outer_or = expr.Or(expected_cond1, expected_inner_and)
+
+        assert repr(result) == repr(expected_outer_or)
+
+    def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client):
+        """
+        check composite filter with unsupported operator type
+        """
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=_helpers.encode_value("val1"),
+        )
+        composite_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
+            filters=[query_pb.StructuredQuery.Filter(field_filter=filter1_pb)],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(
+            composite_filter=composite_pb
+        )
+
+        with pytest.raises(TypeError, match="Unexpected 
CompositeFilter operator type"): + BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) + + @pytest.mark.parametrize( + "op_enum, expected_expr_func", + [ + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, + lambda x: x.equal(float("nan")), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, + lambda x: expr.Not(x.equal(float("nan"))), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, + lambda x: x.equal(None), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, + lambda x: expr.Not(x.equal(None)), + ), + ], + ) + def test__from_query_filter_pb_unary_filter( + self, mock_client, op_enum, expected_expr_func + ): + """ + test supported unary filters + """ + field_path = "unary_field" + filter_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) + + result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) + + field_expr_inst = Field.of(field_path) + expected_condition = expected_expr_func(field_expr_inst) + # should include existance checks + expected = expr.And(field_expr_inst.exists(), expected_condition) + + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): + """ + check unary filter with unsupported operator type + """ + field_path = "unary_field" + filter_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=query_pb.StructuredQuery.UnaryFilter.Operator.OPERATOR_UNSPECIFIED, # Unknown op + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) + + with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): + BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) + + @pytest.mark.parametrize( + "op_enum, value, expected_expr_func", + [ + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + 10, + Expression.less_than, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, + 10, + Expression.less_than_or_equal, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + 10, + Expression.greater_than, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, + 10, + Expression.greater_than_or_equal, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expression.equal), + ( + query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, + 10, + Expression.not_equal, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, + 10, + Expression.array_contains, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, + [10, 20], + Expression.array_contains_any, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.IN, + [10, 20], + Expression.equal_any, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, + [10, 20], + Expression.not_equal_any, + ), + ], + ) + def test__from_query_filter_pb_field_filter( + self, mock_client, op_enum, value, expected_expr_func + ): + """ + test supported field filters + """ + field_path = "test_field" + value_pb = _helpers.encode_value(value) + filter_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + value=value_pb, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) + + result = 
BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) + + field_expr = Field.of(field_path) + # convert values into constants + value = ( + [Constant(e) for e in value] if isinstance(value, list) else Constant(value) + ) + expected_condition = expected_expr_func(field_expr, value) + # should include existance checks + expected = expr.And(field_expr.exists(), expected_condition) + + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): + """ + check field filter with unsupported operator type + """ + field_path = "test_field" + value_pb = _helpers.encode_value(10) + filter_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=query_pb.StructuredQuery.FieldFilter.Operator.OPERATOR_UNSPECIFIED, # Unknown op + value=value_pb, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) + + with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): + BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) + + def test__from_query_filter_pb_unknown_filter_type(self, mock_client): + """ + test with unsupported filter type + """ + # Test with an unexpected protobuf type + with pytest.raises(TypeError, match="Unexpected filter type"): + BooleanExpression._from_query_filter_pb(document_pb.Value(), mock_client) + + +class TestFunctionExpression: + def test_equals(self): + assert expr.FunctionExpression.sqrt("1") == expr.FunctionExpression.sqrt("1") + assert expr.FunctionExpression.sqrt("1") != expr.FunctionExpression.sqrt("2") + assert expr.FunctionExpression.sqrt("1") != expr.FunctionExpression.sum("1") + assert expr.FunctionExpression.sqrt("1") != object() + + +class TestArray: + """Tests for the array class""" + + def test_array(self): + arg1 = Field.of("field1") + instance = expr.Array([arg1]) + assert instance.name == "array" + assert instance.params == [arg1] + assert repr(instance) == "Array([Field.of('field1')])" + + def test_empty_array(self): + instance = expr.Array([]) + assert instance.name == "array" + assert instance.params == [] + assert repr(instance) == "Array([])" + + def test_array_w_primitives(self): + a = expr.Array([1, Constant.of(2), "3"]) + assert a.name == "array" + assert a.params == [Constant.of(1), Constant.of(2), Constant.of("3")] + assert repr(a) == "Array([Constant.of(1), Constant.of(2), Constant.of('3')])" + + def test_array_w_non_list(self): + with pytest.raises(TypeError): + expr.Array(1) + + +class TestMap: + """Tests for the map class""" + + def test_map(self): + instance = expr.Map({Constant.of("a"): Constant.of("b")}) + assert instance.name == "map" + assert instance.params == [Constant.of("a"), Constant.of("b")] + assert repr(instance) == "Map({'a': 'b'})" + + def test_map_w_primitives(self): + instance = expr.Map({"a": "b", "0": 0, "bool": True}) + assert instance.params == [ + Constant.of("a"), + Constant.of("b"), + Constant.of("0"), + Constant.of(0), + Constant.of("bool"), + Constant.of(True), + ] + assert repr(instance) == "Map({'a': 'b', '0': 0, 'bool': True})" + + def test_empty_map(self): + instance = expr.Map({}) + assert instance.name == "map" + assert instance.params == [] + assert repr(instance) == "Map({})" + + def test_w_exprs(self): + instance = expr.Map({Constant.of("a"): expr.Array([1, 2, 3])}) + assert instance.params == [Constant.of("a"), expr.Array([1, 2, 3])] + assert ( + repr(instance) + == "Map({'a': Array([Constant.of(1), Constant.of(2), 
Constant.of(3)])})"
+        )
+
+
+class TestExpressionMethods:
+    """
+    contains test methods for each Expression method
+    """
+
+    @pytest.mark.parametrize(
+        "first,second,expected",
+        [
+            (
+                Field.of("a").char_length(),
+                Field.of("a").char_length(),
+                True,
+            ),
+            (
+                Field.of("a").char_length(),
+                Field.of("b").char_length(),
+                False,
+            ),
+            (
+                Field.of("a").char_length(),
+                Field.of("a").byte_length(),
+                False,
+            ),
+            (
+                Field.of("a").char_length(),
+                Field.of("b").byte_length(),
+                False,
+            ),
+            (
+                Constant.of("").byte_length(),
+                Field.of("").byte_length(),
+                False,
+            ),
+            (Field.of("").byte_length(), Field.of("").byte_length(), True),
+        ],
+    )
+    def test_equality(self, first, second, expected):
+        assert (first == second) is expected
+
+    def _make_arg(self, name="Mock"):
+        class MockExpression(Constant):
+            def __repr__(self):
+                return self.value
+
+        arg = MockExpression(name)
+        return arg
+
+    def test_expression_wrong_first_type(self):
+        """The first argument should always be an expression or field name"""
+        expected_message = "must be called on an Expression or a string representing a field. got ."
+        with pytest.raises(TypeError) as e1:
+            Expression.logical_minimum(5, 1)
+        assert str(e1.value) == f"'logical_minimum' {expected_message}"
+        with pytest.raises(TypeError) as e2:
+            Expression.sqrt(9)
+        assert str(e2.value) == f"'sqrt' {expected_message}"
+
+    def test_expression_w_string(self):
+        """should be able to use a string for the first argument; it should be interpreted as a Field"""
+        instance = Expression.logical_minimum("first", "second")
+        assert isinstance(instance.params[0], Field)
+        assert instance.params[0].path == "first"
+
+    def test_and(self):
+        arg1 = self._make_arg()
+        arg2 = self._make_arg()
+        arg3 = self._make_arg()
+        instance = expr.And(arg1, arg2, arg3)
+        assert instance.name == "and"
+        assert instance.params == [arg1, arg2, arg3]
+        assert repr(instance) == "And(Mock, Mock, Mock)"
+
+    def test_or(self):
+        arg1 = self._make_arg("Arg1")
+        arg2 = self._make_arg("Arg2")
+        instance = expr.Or(arg1, arg2)
+        assert instance.name == "or"
+        assert instance.params == [arg1, arg2]
+        assert repr(instance) == "Or(Arg1, Arg2)"
+
+    def test_array_get(self):
+        arg1 = self._make_arg("ArrayField")
+        arg2 = self._make_arg("Offset")
+        instance = Expression.array_get(arg1, arg2)
+        assert instance.name == "array_get"
+        assert instance.params == [arg1, arg2]
+        assert repr(instance) == "ArrayField.array_get(Offset)"
+        infix_instance = arg1.array_get(arg2)
+        assert infix_instance == instance
+
+    def test_array_contains(self):
+        arg1 = self._make_arg("ArrayField")
+        arg2 = self._make_arg("Element")
+        instance = Expression.array_contains(arg1, arg2)
+        assert instance.name == "array_contains"
+        assert instance.params == [arg1, arg2]
+        assert repr(instance) == "ArrayField.array_contains(Element)"
+        infix_instance = arg1.array_contains(arg2)
+        assert infix_instance == instance
+
+    def test_array_contains_any(self):
+        arg1 = self._make_arg("ArrayField")
+        arg2 = self._make_arg("Element1")
+        arg3 = self._make_arg("Element2")
+        instance = Expression.array_contains_any(arg1, [arg2, arg3])
+        assert instance.name == "array_contains_any"
+        assert isinstance(instance.params[1], expr.Array)
+        assert instance.params[0] == arg1
+        assert instance.params[1].params == [arg2, arg3]
+        assert (
+            repr(instance)
+            == "ArrayField.array_contains_any(Array([Element1, Element2]))"
+        )
+        infix_instance = arg1.array_contains_any([arg2, arg3])
+        assert infix_instance == instance
+
+    def test_exists(self):
+        arg1 = 
self._make_arg("Field") + instance = Expression.exists(arg1) + assert instance.name == "exists" + assert instance.params == [arg1] + assert repr(instance) == "Field.exists()" + infix_instance = arg1.exists() + assert infix_instance == instance + + def test_equal(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.equal(arg1, arg2) + assert instance.name == "equal" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.equal(Right)" + infix_instance = arg1.equal(arg2) + assert infix_instance == instance + + def test_greater_than_or_equal(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.greater_than_or_equal(arg1, arg2) + assert instance.name == "greater_than_or_equal" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.greater_than_or_equal(Right)" + infix_instance = arg1.greater_than_or_equal(arg2) + assert infix_instance == instance + + def test_greater_than(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.greater_than(arg1, arg2) + assert instance.name == "greater_than" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.greater_than(Right)" + infix_instance = arg1.greater_than(arg2) + assert infix_instance == instance + + def test_less_than_or_equal(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.less_than_or_equal(arg1, arg2) + assert instance.name == "less_than_or_equal" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.less_than_or_equal(Right)" + infix_instance = arg1.less_than_or_equal(arg2) + assert infix_instance == instance + + def test_less_than(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.less_than(arg1, arg2) + assert instance.name == "less_than" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.less_than(Right)" + infix_instance = arg1.less_than(arg2) + assert infix_instance == instance + + def test_not_equal(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.not_equal(arg1, arg2) + assert instance.name == "not_equal" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.not_equal(Right)" + infix_instance = arg1.not_equal(arg2) + assert infix_instance == instance + + def test_equal_any(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = Expression.equal_any(arg1, [arg2, arg3]) + assert instance.name == "equal_any" + assert isinstance(instance.params[1], expr.Array) + assert instance.params[0] == arg1 + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.equal_any(Array([Value1, Value2]))" + infix_instance = arg1.equal_any([arg2, arg3]) + assert infix_instance == instance + + def test_not_equal_any(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = Expression.not_equal_any(arg1, [arg2, arg3]) + assert instance.name == "not_equal_any" + assert isinstance(instance.params[1], expr.Array) + assert instance.params[0] == arg1 + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.not_equal_any(Array([Value1, Value2]))" + infix_instance = arg1.not_equal_any([arg2, arg3]) + assert infix_instance == instance + + def test_is_absent(self): + arg1 = self._make_arg("Field") + instance = 
Expression.is_absent(arg1) + assert instance.name == "is_absent" + assert instance.params == [arg1] + assert repr(instance) == "Field.is_absent()" + infix_instance = arg1.is_absent() + assert infix_instance == instance + + def test_if_absent(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("ThenExpression") + instance = Expression.if_absent(arg1, arg2) + assert instance.name == "if_absent" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Field.if_absent(ThenExpression)" + infix_instance = arg1.if_absent(arg2) + assert infix_instance == instance + + def test_is_error(self): + arg1 = self._make_arg("Value") + instance = Expression.is_error(arg1) + assert instance.name == "is_error" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_error()" + infix_instance = arg1.is_error() + assert infix_instance == instance + + def test_if_error(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("ThenExpression") + instance = Expression.if_error(arg1, arg2) + assert instance.name == "if_error" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.if_error(ThenExpression)" + infix_instance = arg1.if_error(arg2) + assert infix_instance == instance + + def test_not(self): + arg1 = self._make_arg("Condition") + instance = expr.Not(arg1) + assert instance.name == "not" + assert instance.params == [arg1] + assert repr(instance) == "Not(Condition)" + + def test_array_contains_all(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element1") + arg3 = self._make_arg("Element2") + instance = Expression.array_contains_all(arg1, [arg2, arg3]) + assert instance.name == "array_contains_all" + assert isinstance(instance.params[1], expr.Array) + assert instance.params[0] == arg1 + assert instance.params[1].params == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_all(Array([Element1, Element2]))" + ) + infix_instance = arg1.array_contains_all([arg2, arg3]) + assert infix_instance == instance + + def test_ends_with(self): + arg1 = self._make_arg("Expression") + arg2 = self._make_arg("Postfix") + instance = Expression.ends_with(arg1, arg2) + assert instance.name == "ends_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expression.ends_with(Postfix)" + infix_instance = arg1.ends_with(arg2) + assert infix_instance == instance + + def test_conditional(self): + arg1 = self._make_arg("Condition") + arg2 = self._make_arg("ThenExpression") + arg3 = self._make_arg("ElseExpression") + instance = expr.Conditional(arg1, arg2, arg3) + assert instance.name == "conditional" + assert instance.params == [arg1, arg2, arg3] + assert ( + repr(instance) == "Conditional(Condition, ThenExpression, ElseExpression)" + ) + + def test_like(self): + arg1 = self._make_arg("Expression") + arg2 = self._make_arg("Pattern") + instance = Expression.like(arg1, arg2) + assert instance.name == "like" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expression.like(Pattern)" + infix_instance = arg1.like(arg2) + assert infix_instance == instance + + def test_regex_contains(self): + arg1 = self._make_arg("Expression") + arg2 = self._make_arg("Regex") + instance = Expression.regex_contains(arg1, arg2) + assert instance.name == "regex_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expression.regex_contains(Regex)" + infix_instance = arg1.regex_contains(arg2) + assert infix_instance == instance + + def test_regex_match(self): + arg1 = 
self._make_arg("Expression") + arg2 = self._make_arg("Regex") + instance = Expression.regex_match(arg1, arg2) + assert instance.name == "regex_match" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expression.regex_match(Regex)" + infix_instance = arg1.regex_match(arg2) + assert infix_instance == instance + + def test_starts_with(self): + arg1 = self._make_arg("Expression") + arg2 = self._make_arg("Prefix") + instance = Expression.starts_with(arg1, arg2) + assert instance.name == "starts_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expression.starts_with(Prefix)" + infix_instance = arg1.starts_with(arg2) + assert infix_instance == instance + + def test_string_contains(self): + arg1 = self._make_arg("Expression") + arg2 = self._make_arg("Substring") + instance = Expression.string_contains(arg1, arg2) + assert instance.name == "string_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expression.string_contains(Substring)" + infix_instance = arg1.string_contains(arg2) + assert infix_instance == instance + + def test_xor(self): + arg1 = self._make_arg("Condition1") + arg2 = self._make_arg("Condition2") + instance = expr.Xor([arg1, arg2]) + assert instance.name == "xor" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Xor(Condition1, Condition2)" + + def test_divide(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.divide(arg1, arg2) + assert instance.name == "divide" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.divide(Right)" + infix_instance = arg1.divide(arg2) + assert infix_instance == instance + + def test_logical_maximum(self): + arg1 = self._make_arg("A1") + arg2 = self._make_arg("A2") + arg3 = self._make_arg("A3") + instance = Expression.logical_maximum(arg1, arg2, arg3) + assert instance.name == "maximum" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "A1.logical_maximum(A2, A3)" + infix_instance = arg1.logical_maximum(arg2, arg3) + assert infix_instance == instance + + def test_logical_minimum(self): + arg1 = self._make_arg("A1") + arg2 = self._make_arg("A2") + arg3 = self._make_arg("A3") + instance = Expression.logical_minimum(arg1, arg2, arg3) + assert instance.name == "minimum" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "A1.logical_minimum(A2, A3)" + infix_instance = arg1.logical_minimum(arg2, arg3) + assert infix_instance == instance + + def test_to_lower(self): + arg1 = self._make_arg("Input") + instance = Expression.to_lower(arg1) + assert instance.name == "to_lower" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_lower()" + infix_instance = arg1.to_lower() + assert infix_instance == instance + + def test_to_upper(self): + arg1 = self._make_arg("Input") + instance = Expression.to_upper(arg1) + assert instance.name == "to_upper" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_upper()" + infix_instance = arg1.to_upper() + assert infix_instance == instance + + def test_trim(self): + arg1 = self._make_arg("Input") + instance = Expression.trim(arg1) + assert instance.name == "trim" + assert instance.params == [arg1] + assert repr(instance) == "Input.trim()" + infix_instance = arg1.trim() + assert infix_instance == instance + + def test_string_reverse(self): + arg1 = self._make_arg("Input") + instance = Expression.string_reverse(arg1) + assert instance.name == "string_reverse" + assert instance.params == 
[arg1] + assert repr(instance) == "Input.string_reverse()" + infix_instance = arg1.string_reverse() + assert infix_instance == instance + + def test_substring(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + instance = Expression.substring(arg1, arg2) + assert instance.name == "substring" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Input.substring(Position)" + infix_instance = arg1.substring(arg2) + assert infix_instance == instance + + def test_substring_w_length(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + arg3 = self._make_arg("Length") + instance = Expression.substring(arg1, arg2, arg3) + assert instance.name == "substring" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Input.substring(Position, Length)" + infix_instance = arg1.substring(arg2, arg3) + assert infix_instance == instance + + def test_join(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("Separator") + instance = Expression.join(arg1, arg2) + assert instance.name == "join" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Array.join(Separator)" + infix_instance = arg1.join(arg2) + assert infix_instance == instance + + def test_map_get(self): + arg1 = self._make_arg("Map") + arg2 = "key" + instance = Expression.map_get(arg1, arg2) + assert instance.name == "map_get" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_get(Constant.of('key'))" + infix_instance = arg1.map_get(Constant.of(arg2)) + assert infix_instance == instance + + def test_map_remove(self): + arg1 = self._make_arg("Map") + arg2 = "key" + instance = Expression.map_remove(arg1, arg2) + assert instance.name == "map_remove" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_remove(Constant.of('key'))" + infix_instance = arg1.map_remove(Constant.of(arg2)) + assert infix_instance == instance + + def test_map_merge(self): + arg1 = expr.Map({"a": 1}) + arg2 = expr.Map({"b": 2}) + arg3 = {"c": 3} + instance = Expression.map_merge(arg1, arg2, arg3) + assert instance.name == "map_merge" + assert instance.params == [arg1, arg2, expr.Map(arg3)] + assert repr(instance) == "Map({'a': 1}).map_merge(Map({'b': 2}), Map({'c': 3}))" + infix_instance = arg1.map_merge(arg2, arg3) + assert infix_instance == instance + + def test_mod(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.mod(arg1, arg2) + assert instance.name == "mod" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.mod(Right)" + infix_instance = arg1.mod(arg2) + assert infix_instance == instance + + def test_multiply(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.multiply(arg1, arg2) + assert instance.name == "multiply" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.multiply(Right)" + infix_instance = arg1.multiply(arg2) + assert infix_instance == instance + + def test_string_concat(self): + arg1 = self._make_arg("Str1") + arg2 = self._make_arg("Str2") + arg3 = self._make_arg("Str3") + instance = Expression.string_concat(arg1, arg2, arg3) + assert instance.name == "string_concat" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Str1.string_concat(Str2, Str3)" + infix_instance = arg1.string_concat(arg2, arg3) + assert infix_instance == instance + + def test_subtract(self): + arg1 = self._make_arg("Left") + arg2 = 
self._make_arg("Right") + instance = Expression.subtract(arg1, arg2) + assert instance.name == "subtract" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.subtract(Right)" + infix_instance = arg1.subtract(arg2) + assert infix_instance == instance + + def test_current_timestamp(self): + instance = expr.CurrentTimestamp() + assert instance.name == "current_timestamp" + assert instance.params == [] + assert repr(instance) == "CurrentTimestamp()" + + def test_timestamp_add(self): + arg1 = self._make_arg("Timestamp") + arg2 = self._make_arg("Unit") + arg3 = self._make_arg("Amount") + instance = Expression.timestamp_add(arg1, arg2, arg3) + assert instance.name == "timestamp_add" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Timestamp.timestamp_add(Unit, Amount)" + infix_instance = arg1.timestamp_add(arg2, arg3) + assert infix_instance == instance + + def test_timestamp_subtract(self): + arg1 = self._make_arg("Timestamp") + arg2 = self._make_arg("Unit") + arg3 = self._make_arg("Amount") + instance = Expression.timestamp_subtract(arg1, arg2, arg3) + assert instance.name == "timestamp_subtract" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Timestamp.timestamp_subtract(Unit, Amount)" + infix_instance = arg1.timestamp_subtract(arg2, arg3) + assert infix_instance == instance + + def test_timestamp_to_unix_micros(self): + arg1 = self._make_arg("Input") + instance = Expression.timestamp_to_unix_micros(arg1) + assert instance.name == "timestamp_to_unix_micros" + assert instance.params == [arg1] + assert repr(instance) == "Input.timestamp_to_unix_micros()" + infix_instance = arg1.timestamp_to_unix_micros() + assert infix_instance == instance + + def test_timestamp_to_unix_millis(self): + arg1 = self._make_arg("Input") + instance = Expression.timestamp_to_unix_millis(arg1) + assert instance.name == "timestamp_to_unix_millis" + assert instance.params == [arg1] + assert repr(instance) == "Input.timestamp_to_unix_millis()" + infix_instance = arg1.timestamp_to_unix_millis() + assert infix_instance == instance + + def test_timestamp_to_unix_seconds(self): + arg1 = self._make_arg("Input") + instance = Expression.timestamp_to_unix_seconds(arg1) + assert instance.name == "timestamp_to_unix_seconds" + assert instance.params == [arg1] + assert repr(instance) == "Input.timestamp_to_unix_seconds()" + infix_instance = arg1.timestamp_to_unix_seconds() + assert infix_instance == instance + + def test_unix_micros_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = Expression.unix_micros_to_timestamp(arg1) + assert instance.name == "unix_micros_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "Input.unix_micros_to_timestamp()" + infix_instance = arg1.unix_micros_to_timestamp() + assert infix_instance == instance + + def test_unix_millis_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = Expression.unix_millis_to_timestamp(arg1) + assert instance.name == "unix_millis_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "Input.unix_millis_to_timestamp()" + infix_instance = arg1.unix_millis_to_timestamp() + assert infix_instance == instance + + def test_unix_seconds_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = Expression.unix_seconds_to_timestamp(arg1) + assert instance.name == "unix_seconds_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "Input.unix_seconds_to_timestamp()" + infix_instance = 
arg1.unix_seconds_to_timestamp() + assert infix_instance == instance + + def test_euclidean_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expression.euclidean_distance(arg1, arg2) + assert instance.name == "euclidean_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.euclidean_distance(Vector2)" + infix_instance = arg1.euclidean_distance(arg2) + assert infix_instance == instance + + def test_cosine_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expression.cosine_distance(arg1, arg2) + assert instance.name == "cosine_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.cosine_distance(Vector2)" + infix_instance = arg1.cosine_distance(arg2) + assert infix_instance == instance + + def test_dot_product(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expression.dot_product(arg1, arg2) + assert instance.name == "dot_product" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.dot_product(Vector2)" + infix_instance = arg1.dot_product(arg2) + assert infix_instance == instance + + @pytest.mark.parametrize( + "method", ["euclidean_distance", "cosine_distance", "dot_product"] + ) + @pytest.mark.parametrize( + "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), []] + ) + def test_vector_ctor(self, method, input): + """ + test constructing various vector expressions with + different inputs + """ + arg1 = self._make_arg("VectorRef") + instance = getattr(arg1, method)(input) + assert instance.name == method + got_second_param = instance.params[1] + assert isinstance(got_second_param, Constant) + assert isinstance(got_second_param.value, Vector) + + def test_vector_length(self): + arg1 = self._make_arg("Array") + instance = Expression.vector_length(arg1) + assert instance.name == "vector_length" + assert instance.params == [arg1] + assert repr(instance) == "Array.vector_length()" + infix_instance = arg1.vector_length() + assert infix_instance == instance + + def test_add(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.add(arg1, arg2) + assert instance.name == "add" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.add(Right)" + infix_instance = arg1.add(arg2) + assert infix_instance == instance + + def test_abs(self): + arg1 = self._make_arg("Value") + instance = Expression.abs(arg1) + assert instance.name == "abs" + assert instance.params == [arg1] + assert repr(instance) == "Value.abs()" + infix_instance = arg1.abs() + assert infix_instance == instance + + def test_ceil(self): + arg1 = self._make_arg("Value") + instance = Expression.ceil(arg1) + assert instance.name == "ceil" + assert instance.params == [arg1] + assert repr(instance) == "Value.ceil()" + infix_instance = arg1.ceil() + assert infix_instance == instance + + def test_exp(self): + arg1 = self._make_arg("Value") + instance = Expression.exp(arg1) + assert instance.name == "exp" + assert instance.params == [arg1] + assert repr(instance) == "Value.exp()" + infix_instance = arg1.exp() + assert infix_instance == instance + + def test_floor(self): + arg1 = self._make_arg("Value") + instance = Expression.floor(arg1) + assert instance.name == "floor" + assert instance.params == [arg1] + assert repr(instance) == "Value.floor()" + infix_instance = arg1.floor() + assert infix_instance == instance + + def test_ln(self): + arg1 
= self._make_arg("Value") + instance = Expression.ln(arg1) + assert instance.name == "ln" + assert instance.params == [arg1] + assert repr(instance) == "Value.ln()" + infix_instance = arg1.ln() + assert infix_instance == instance + + def test_log(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Base") + instance = Expression.log(arg1, arg2) + assert instance.name == "log" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.log(Base)" + infix_instance = arg1.log(arg2) + assert infix_instance == instance + + def test_log10(self): + arg1 = self._make_arg("Value") + instance = Expression.log10(arg1) + assert instance.name == "log10" + assert instance.params == [arg1] + assert repr(instance) == "Value.log10()" + infix_instance = arg1.log10() + assert infix_instance == instance + + def test_pow(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Exponent") + instance = Expression.pow(arg1, arg2) + assert instance.name == "pow" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.pow(Exponent)" + infix_instance = arg1.pow(arg2) + assert infix_instance == instance + + def test_round(self): + arg1 = self._make_arg("Value") + instance = Expression.round(arg1) + assert instance.name == "round" + assert instance.params == [arg1] + assert repr(instance) == "Value.round()" + infix_instance = arg1.round() + assert infix_instance == instance + + def test_sqrt(self): + arg1 = self._make_arg("Value") + instance = Expression.sqrt(arg1) + assert instance.name == "sqrt" + assert instance.params == [arg1] + assert repr(instance) == "Value.sqrt()" + infix_instance = arg1.sqrt() + assert infix_instance == instance + + def test_array_length(self): + arg1 = self._make_arg("Array") + instance = Expression.array_length(arg1) + assert instance.name == "array_length" + assert instance.params == [arg1] + assert repr(instance) == "Array.array_length()" + infix_instance = arg1.array_length() + assert infix_instance == instance + + def test_array_reverse(self): + arg1 = self._make_arg("Array") + instance = Expression.array_reverse(arg1) + assert instance.name == "array_reverse" + assert instance.params == [arg1] + assert repr(instance) == "Array.array_reverse()" + infix_instance = arg1.array_reverse() + assert infix_instance == instance + + def test_array_concat(self): + arg1 = self._make_arg("ArrayRef1") + arg2 = self._make_arg("ArrayRef2") + instance = Expression.array_concat(arg1, arg2) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayRef1.array_concat(ArrayRef2)" + infix_instance = arg1.array_concat(arg2) + assert infix_instance == instance + + def test_array_concat_multiple(self): + arg1 = expr.Array([Constant.of(0)]) + arg2 = Field.of("ArrayRef2") + arg3 = Field.of("ArrayRef3") + arg4 = [self._make_arg("Constant")] + instance = arg1.array_concat(arg2, arg3, arg4) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2, arg3, expr.Array(arg4)] + assert ( + repr(instance) + == "Array([Constant.of(0)]).array_concat(Field.of('ArrayRef2'), Field.of('ArrayRef3'), Array([Constant]))" + ) + + def test_byte_length(self): + arg1 = self._make_arg("Expression") + instance = Expression.byte_length(arg1) + assert instance.name == "byte_length" + assert instance.params == [arg1] + assert repr(instance) == "Expression.byte_length()" + infix_instance = arg1.byte_length() + assert infix_instance == instance + + def test_char_length(self): + arg1 = 
self._make_arg("Expression") + instance = Expression.char_length(arg1) + assert instance.name == "char_length" + assert instance.params == [arg1] + assert repr(instance) == "Expression.char_length()" + infix_instance = arg1.char_length() + assert infix_instance == instance + + def test_concat(self): + arg1 = self._make_arg("First") + arg2 = self._make_arg("Second") + arg3 = "Third" + instance = Expression.concat(arg1, arg2, arg3) + assert instance.name == "concat" + assert instance.params == [arg1, arg2, Constant.of(arg3)] + assert repr(instance) == "First.concat(Second, Constant.of('Third'))" + infix_instance = arg1.concat(arg2, arg3) + assert infix_instance == instance + + def test_length(self): + arg1 = self._make_arg("Expression") + instance = Expression.length(arg1) + assert instance.name == "length" + assert instance.params == [arg1] + assert repr(instance) == "Expression.length()" + infix_instance = arg1.length() + assert infix_instance == instance + + def test_collection_id(self): + arg1 = self._make_arg("Value") + instance = Expression.collection_id(arg1) + assert instance.name == "collection_id" + assert instance.params == [arg1] + assert repr(instance) == "Value.collection_id()" + infix_instance = arg1.collection_id() + assert infix_instance == instance + + def test_document_id(self): + arg1 = self._make_arg("Value") + instance = Expression.document_id(arg1) + assert instance.name == "document_id" + assert instance.params == [arg1] + assert repr(instance) == "Value.document_id()" + infix_instance = arg1.document_id() + assert infix_instance == instance + + def test_sum(self): + arg1 = self._make_arg("Value") + instance = Expression.sum(arg1) + assert instance.name == "sum" + assert instance.params == [arg1] + assert repr(instance) == "Value.sum()" + infix_instance = arg1.sum() + assert infix_instance == instance + + def test_average(self): + arg1 = self._make_arg("Value") + instance = Expression.average(arg1) + assert instance.name == "average" + assert instance.params == [arg1] + assert repr(instance) == "Value.average()" + infix_instance = arg1.average() + assert infix_instance == instance + + def test_count(self): + arg1 = self._make_arg("Value") + instance = Expression.count(arg1) + assert instance.name == "count" + assert instance.params == [arg1] + assert repr(instance) == "Value.count()" + infix_instance = arg1.count() + assert infix_instance == instance + + def test_base_count(self): + instance = expr.Count() + assert instance.name == "count" + assert instance.params == [] + assert repr(instance) == "Count()" + + def test_count_if(self): + arg1 = self._make_arg("Value") + instance = Expression.count_if(arg1) + assert instance.name == "count_if" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_if()" + infix_instance = arg1.count_if() + assert infix_instance == instance + + def test_count_distinct(self): + arg1 = self._make_arg("Value") + instance = Expression.count_distinct(arg1) + assert instance.name == "count_distinct" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_distinct()" + infix_instance = arg1.count_distinct() + assert infix_instance == instance + + def test_minimum(self): + arg1 = self._make_arg("Value") + instance = Expression.minimum(arg1) + assert instance.name == "minimum" + assert instance.params == [arg1] + assert repr(instance) == "Value.minimum()" + infix_instance = arg1.minimum() + assert infix_instance == instance + + def test_maximum(self): + arg1 = self._make_arg("Value") + instance = 
Expression.maximum(arg1)
+        assert instance.name == "maximum"
+        assert instance.params == [arg1]
+        assert repr(instance) == "Value.maximum()"
+        infix_instance = arg1.maximum()
+        assert infix_instance == instance
diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py
new file mode 100644
index 000000000..3650074bc
--- /dev/null
+++ b/tests/unit/v1/test_pipeline_result.py
@@ -0,0 +1,498 @@
+# Copyright 2025 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import mock
+import pytest
+
+from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
+from google.cloud.firestore_v1.pipeline_expressions import Constant
+from google.cloud.firestore_v1.pipeline_result import PipelineResult
+from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot
+from google.cloud.firestore_v1.pipeline_result import PipelineStream
+from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream
+from google.cloud.firestore_v1.query_profile import QueryExplainError
+from google.cloud.firestore_v1.query_profile import PipelineExplainOptions
+from google.cloud.firestore_v1._helpers import encode_value
+from google.cloud.firestore_v1.types.document import Document
+from google.protobuf.timestamp_pb2 import Timestamp
+
+
+_mock_stream_responses = [
+    ExecutePipelineResponse(
+        results=[Document(name="projects/p/databases/d/documents/c/d1", fields={})],
+        execution_time=Timestamp(seconds=1, nanos=2),
+        explain_stats={"data": {}},
+    ),
+    ExecutePipelineResponse(
+        results=[Document(name="projects/p/databases/d/documents/c/d2", fields={})],
+        execution_time=Timestamp(seconds=3, nanos=4),
+    ),
+]
+
+
+class TestPipelineResult:
+    def _make_one(self, *args, **kwargs):
+        if not args:
+            # use defaults if not passed
+            args = [mock.Mock(), {}]
+        return PipelineResult(*args, **kwargs)
+
+    def test_ref(self):
+        expected = object()
+        instance = self._make_one(ref=expected)
+        assert instance.ref == expected
+        # should be None if not set
+        assert self._make_one().ref is None
+
+    def test_id(self):
+        ref = mock.Mock()
+        ref.id = "test"
+        instance = self._make_one(ref=ref)
+        assert instance.id == "test"
+        # should be None if not set
+        assert self._make_one().id is None
+
+    def test_create_time(self):
+        expected = object()
+        instance = self._make_one(create_time=expected)
+        assert instance.create_time == expected
+        # should be None if not set
+        assert self._make_one().create_time is None
+
+    def test_update_time(self):
+        expected = object()
+        instance = self._make_one(update_time=expected)
+        assert instance.update_time == expected
+        # should be None if not set
+        assert self._make_one().update_time is None
+
+    def test_execution_time(self):
+        expected = object()
+        instance = self._make_one(execution_time=expected)
+        assert instance.execution_time == expected
+        # should raise if not set
+        with pytest.raises(ValueError) as e:
+            self._make_one().execution_time
+        assert "execution_time" in str(e)
+
+    @pytest.mark.parametrize(
+        
"first,second,result", + [ + ((object(), {}), (object(), {}), True), + ((object(), {1: 1}), (object(), {1: 1}), True), + ((object(), {1: 1}), (object(), {2: 2}), False), + ((object(), {}, "ref"), (object(), {}, "ref"), True), + ((object(), {}, "ref"), (object(), {}, "diff"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "ref"), True), + ((object(), {1: 1}, "ref"), (object(), {2: 2}, "ref"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "diff"), False), + ( + (object(), {1: 1}, "ref", 1, 2, 3), + (object(), {1: 1}, "ref", 4, 5, 6), + True, + ), + ], + ) + def test_eq(self, first, second, result): + first_obj = self._make_one(*first) + second_obj = self._make_one(*second) + assert (first_obj == second_obj) is result + + def test_eq_wrong_type(self): + instance = self._make_one() + result = instance == object() + assert result is False + + def test_data(self): + from google.cloud.firestore_v1.types.document import Value + + client = mock.Mock() + data = {"str": Value(string_value="hello world"), "int": Value(integer_value=5)} + instance = self._make_one(client, data) + got = instance.data() + assert len(got) == 2 + assert got["str"] == "hello world" + assert got["int"] == 5 + + def test_data_none(self): + client = object() + data = None + instance = self._make_one(client, data) + assert instance.data() is None + + def test_data_call(self): + """ + ensure decode_dict is called on .data + """ + client = object() + data = {"hello": "world"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_dict" + ) as decode_mock: + got = instance.data() + decode_mock.assert_called_once_with(data, client) + assert got == decode_mock.return_value + + def test_get(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"key": Value(string_value="hello world")} + instance = self._make_one(client, data) + got = instance.get("key") + assert got == "hello world" + + def test_get_nested(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + instance = self._make_one(client, data) + got = instance.get("first.second") + assert got == "hello world" + + def test_get_field_path(self): + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.field_path import FieldPath + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + path = FieldPath.from_string("first.second") + instance = self._make_one(client, data) + got = instance.get(path) + assert got == "hello world" + + def test_get_failure(self): + """ + test calling get on value not in data + """ + client = object() + data = {} + instance = self._make_one(client, data) + with pytest.raises(KeyError): + instance.get("key") + + def test_get_call(self): + """ + ensure decode_value is called on .get() + """ + client = object() + data = {"key": "value"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_value" + ) as decode_mock: + got = instance.get("key") + decode_mock.assert_called_once_with("value", client) + assert got == decode_mock.return_value + + +class TestPipelineSnapshot: + def _make_one(self, *args, **kwargs): + if not args: + # use defaults if not passed + args = [[], mock.Mock()] + return PipelineSnapshot(*args, **kwargs) + + def test_ctor(self): + in_arr = [1, 2, 3] + expected_type = object() + expected_pipeline = 
mock.Mock() + expected_transaction = object() + expected_read_time = 123 + expected_explain_options = object() + expected_addtl_options = {} + source = PipelineStream( + expected_type, + expected_pipeline, + expected_transaction, + expected_read_time, + expected_explain_options, + expected_addtl_options, + ) + instance = self._make_one(in_arr, source) + assert instance._return_type == expected_type + assert instance.pipeline == expected_pipeline + assert instance._client == expected_pipeline._client + assert instance._additonal_options == expected_addtl_options + assert instance._explain_options == expected_explain_options + assert instance._explain_stats is None + assert instance._started is True + assert instance.execution_time is None + assert instance.transaction == expected_transaction + assert instance._read_time == expected_read_time + + def test_list_methods(self): + instance = self._make_one(list(range(10)), mock.Mock()) + assert isinstance(instance, list) + assert len(instance) == 10 + assert instance[0] == 0 + assert instance[-1] == 9 + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + + +class SharedStreamTests: + """ + Shared test logic for PipelineStream and AsyncPipelineStream + """ + + def _make_one(self, *args, **kwargs): + raise NotImplementedError + + def _mock_init_args(self): + # return default mocks for all init args + from google.cloud.firestore_v1.pipeline import Pipeline + + return { + "return_type": PipelineResult, + "pipeline": Pipeline(mock.Mock()), + "transaction": None, + "read_time": None, + "explain_options": None, + "additional_options": {}, + } + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._started = True + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + # fail if not started + instance._started = False + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "not available until query is complete" in str(e) + + @pytest.mark.parametrize( + "init_kwargs,expected_options", + [ + ( + {"explain_options": PipelineExplainOptions()}, + {"explain_options": encode_value({"mode": "analyze"})}, + ), + ( + {"explain_options": PipelineExplainOptions(mode="explain")}, + {"explain_options": encode_value({"mode": "explain"})}, + ), + ( + {"additional_options": {"explain_options": Constant("custom")}}, + {"explain_options": encode_value("custom")}, + ), + ( + {"additional_options": {"explain_options": 
encode_value("custom")}}, + {"explain_options": encode_value("custom")}, + ), + ( + { + "explain_options": PipelineExplainOptions(), + "additional_options": {"explain_options": Constant.of("override")}, + }, + {"explain_options": encode_value("override")}, + ), + ], + ) + def test_build_request_options(self, init_kwargs, expected_options): + """ + Certain Arguments to PipelineStream should be passed to `options` field in proto request + """ + instance = self._make_one(**init_kwargs) + request = instance._build_request() + options = dict(request.structured_pipeline.options) + assert options == expected_options + assert len(options) == len(expected_options) + + def test_build_request_transaction(self): + """Ensure transaction is passed down when building request""" + from google.cloud.firestore_v1.transaction import Transaction + + expected_id = b"expected" + transaction = Transaction(mock.Mock()) + transaction._id = expected_id + instance = self._make_one(transaction=transaction) + request = instance._build_request() + assert request.transaction == expected_id + + def test_build_request_read_time(self): + """Ensure readtime is passed down when building request""" + import datetime + + ts = datetime.datetime.now() + instance = self._make_one(read_time=ts) + request = instance._build_request() + assert request.read_time.timestamp() == ts.timestamp() + + +class TestPipelineStream(SharedStreamTests): + def _make_one(self, **kwargs): + init_kwargs = self._mock_init_args() + init_kwargs.update(kwargs) + return PipelineStream(**init_kwargs) + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._started = True + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + # fail if not started + instance._started = False + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "not available until query is complete" in str(e) + + def test_iter(self): + pipeline = mock.Mock() + pipeline._client.project = "project-id" + pipeline._client._database = "database-id" + pipeline._client.document.side_effect = lambda path: mock.Mock( + id=path.split("/")[-1] + ) + pipeline._to_pb.return_value = {} + + instance = self._make_one(pipeline=pipeline) + + instance._client._firestore_api.execute_pipeline.return_value = ( + _mock_stream_responses + ) + + results = list(instance) + + assert len(results) == 2 + assert isinstance(results[0], PipelineResult) + assert results[0].id == "d1" + assert isinstance(results[1], PipelineResult) + assert results[1].id == "d2" + + assert instance.execution_time.seconds == 1 + assert instance.execution_time.nanos == 2 + + # expect empty stats + got_stats = instance.explain_stats.get_raw().data + assert got_stats.value == b"" + + instance._client._firestore_api.execute_pipeline.assert_called_once() + + def test_double_iterate(self): + instance = self._make_one() + instance._client._firestore_api.execute_pipeline.return_value = [] + # consume the iterator + list(instance) + with pytest.raises(RuntimeError): + list(instance) + 
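# A minimal caller-level sketch of the flow these TestPipelineStream cases
# exercise. It assumes the preview pipeline API on a synchronous Client and
# that the sync Pipeline exposes a stream() entry point analogous to the
# async one; the "books" collection name is only a placeholder.
from google.cloud import firestore

client = firestore.Client()
pipeline_stream = client.pipeline().collection("books").stream()
for result in pipeline_stream:
    # each item is a PipelineResult; data() decodes the raw proto values
    print(result.id, result.data())
# execution_time and explain_stats are populated only after iteration completes
print(pipeline_stream.execution_time)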
+ +class TestAsyncPipelineStream(SharedStreamTests): + def _make_one(self, **kwargs): + init_kwargs = self._mock_init_args() + init_kwargs.update(kwargs) + return AsyncPipelineStream(**init_kwargs) + + @pytest.mark.asyncio + async def test_aiter(self): + pipeline = mock.Mock() + pipeline._client.project = "project-id" + pipeline._client._database = "database-id" + pipeline._client.document.side_effect = lambda path: mock.Mock( + id=path.split("/")[-1] + ) + pipeline._to_pb.return_value = {} + + instance = self._make_one(pipeline=pipeline) + + async def async_gen(items): + for item in items: + yield item + + instance._client._firestore_api.execute_pipeline = mock.AsyncMock( + return_value=async_gen(_mock_stream_responses) + ) + + results = [item async for item in instance] + + assert len(results) == 2 + assert isinstance(results[0], PipelineResult) + assert results[0].id == "d1" + assert isinstance(results[1], PipelineResult) + assert results[1].id == "d2" + + assert instance.execution_time.seconds == 1 + assert instance.execution_time.nanos == 2 + + # expect empty stats + got_stats = instance.explain_stats.get_raw().data + assert got_stats.value == b"" + + instance._client._firestore_api.execute_pipeline.assert_called_once() + + @pytest.mark.asyncio + async def test_double_iterate(self): + instance = self._make_one() + + async def async_gen(items): + for item in items: + yield item # pragma: NO COVER + + # mock the api call to avoid real network requests + instance._client._firestore_api.execute_pipeline = mock.AsyncMock( + return_value=async_gen([]) + ) + + # consume the iterator + [item async for item in instance] + # should fail on second attempt + with pytest.raises(RuntimeError): + [item async for item in instance] diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py new file mode 100644 index 000000000..d6665d4bc --- /dev/null +++ b/tests/unit/v1/test_pipeline_source.py @@ -0,0 +1,127 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +import mock + +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.query import Query +from google.cloud.firestore_v1.async_query import AsyncQuery + +from tests.unit.v1._test_helpers import make_async_client +from tests.unit.v1._test_helpers import make_client + + +class TestPipelineSource: + _expected_pipeline_type = Pipeline + + def _make_client(self): + return make_client() + + def _make_query(self): + return Query(mock.Mock()) + + def test_make_from_client(self): + instance = self._make_client().pipeline() + assert isinstance(instance, PipelineSource) + + def test_create_pipeline(self): + instance = self._make_client().pipeline() + ppl = instance._create_pipeline(None) + assert isinstance(ppl, self._expected_pipeline_type) + + def test_create_from_mock(self): + mock_query = mock.Mock() + expected = object() + mock_query._build_pipeline.return_value = expected + instance = self._make_client().pipeline() + got = instance.create_from(mock_query) + assert got == expected + assert mock_query._build_pipeline.call_count == 1 + assert mock_query._build_pipeline.call_args_list[0][0][0] == instance + + def test_create_from_query(self): + query = self._make_query() + instance = self._make_client().pipeline() + ppl = instance.create_from(query) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + + def test_collection(self): + instance = self._make_client().pipeline() + ppl = instance.collection("path") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/path" + + def test_collection_w_tuple(self): + instance = self._make_client().pipeline() + ppl = instance.collection(("a", "b", "c")) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/a/b/c" + + def test_collection_group(self): + instance = self._make_client().pipeline() + ppl = instance.collection_group("id") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.CollectionGroup) + assert first_stage.collection_id == "id" + + def test_database(self): + instance = self._make_client().pipeline() + ppl = instance.database() + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Database) + + def test_documents(self): + instance = self._make_client().pipeline() + test_documents = [ + BaseDocumentReference("a", "1"), + BaseDocumentReference("a", "2"), + BaseDocumentReference("a", "3"), + ] + ppl = instance.documents(*test_documents) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Documents) + assert len(first_stage.paths) == 3 + assert first_stage.paths[0] == "/a/1" + assert first_stage.paths[1] == "/a/2" + assert 
first_stage.paths[2] == "/a/3" + + +class TestPipelineSourceWithAsyncClient(TestPipelineSource): + """ + When an async client is used, it should produce async pipelines + """ + + _expected_pipeline_type = AsyncPipeline + + def _make_client(self): + return make_async_client() + + def _make_query(self): + return AsyncQuery(mock.Mock()) diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py new file mode 100644 index 000000000..a2d466f47 --- /dev/null +++ b/tests/unit/v1/test_pipeline_stages.py @@ -0,0 +1,855 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest +from unittest import mock + +from google.cloud.firestore_v1.base_pipeline import _BasePipeline +import google.cloud.firestore_v1.pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import ( + Constant, + Field, + Ordering, +) +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + + +class TestStage: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + stages.Stage() + + +class TestAddFields: + def _make_one(self, *args, **kwargs): + return stages.AddFields(*args, **kwargs) + + def test_ctor(self): + field1 = Field.of("field1") + field2_aliased = Field.of("field2").as_("alias2") + instance = self._make_one(field1, field2_aliased) + assert instance.fields == [field1, field2_aliased] + assert instance.name == "add_fields" + + def test_repr(self): + field1 = Field.of("field1").as_("f1") + instance = self._make_one(field1) + repr_str = repr(instance) + assert repr_str == "AddFields(fields=[Field.of('field1').as_('f1')])" + + def test_to_pb(self): + field1 = Field.of("field1") + field2_aliased = Field.of("field2").as_("alias2") + instance = self._make_one(field1, field2_aliased) + result = instance._to_pb() + assert result.name == "add_fields" + assert len(result.args) == 1 + expected_map_value = { + "fields": { + "field1": Value(field_reference_value="field1"), + "alias2": Value(field_reference_value="field2"), + } + } + assert result.args[0].map_value.fields == expected_map_value["fields"] + assert len(result.options) == 0 + + +class TestAggregate: + def _make_one(self, *args, **kwargs): + return stages.Aggregate(*args, **kwargs) + + def test_ctor_positional(self): + """test with only positional arguments""" + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") + instance = self._make_one(sum_total, avg_price) + assert list(instance.accumulators) == [sum_total, avg_price] + assert len(instance.groups) == 0 + assert instance.name == "aggregate" + + def test_ctor_keyword(self): + """test with only keyword arguments""" + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = 
Field.of("price").average().as_("avg_price") + group_category = Field.of("category") + instance = self._make_one( + accumulators=[avg_price, sum_total], groups=[group_category, "city"] + ) + assert instance.accumulators == [avg_price, sum_total] + assert len(instance.groups) == 2 + assert instance.groups[0] == group_category + assert isinstance(instance.groups[1], Field) + assert instance.groups[1].path == "city" + assert instance.name == "aggregate" + + def test_ctor_combined(self): + """test with a mix of arguments""" + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") + count = Field.of("total").count().as_("count") + with pytest.raises(ValueError): + self._make_one(sum_total, accumulators=[avg_price, count]) + + def test_repr(self): + sum_total = Field.of("total").sum().as_("sum_total") + group_category = Field.of("category") + instance = self._make_one(sum_total, groups=[group_category]) + repr_str = repr(instance) + assert ( + repr_str + == "Aggregate(Field.of('total').sum().as_('sum_total'), groups=[Field.of('category')])" + ) + + def test_to_pb(self): + sum_total = Field.of("total").sum().as_("sum_total") + group_category = Field.of("category") + instance = self._make_one(sum_total, groups=[group_category]) + result = instance._to_pb() + assert result.name == "aggregate" + assert len(result.args) == 2 + + expected_accumulators_map = { + "fields": { + "sum_total": Value( + function_value={ + "name": "sum", + "args": [Value(field_reference_value="total")], + } + ) + } + } + assert result.args[0].map_value.fields == expected_accumulators_map["fields"] + + expected_groups_map = { + "fields": {"category": Value(field_reference_value="category")} + } + assert result.args[1].map_value.fields == expected_groups_map["fields"] + assert len(result.options) == 0 + + +class TestCollection: + def _make_one(self, *args, **kwargs): + return stages.Collection(*args, **kwargs) + + @pytest.mark.parametrize( + "input_arg,expected", + [ + ("test", "Collection(path='/test')"), + ("/test", "Collection(path='/test')"), + ], + ) + def test_repr(self, input_arg, expected): + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + input_arg = "test/col" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection" + assert len(result.args) == 1 + assert result.args[0].reference_value == "/test/col" + assert len(result.options) == 0 + + +class TestCollectionGroup: + def _make_one(self, *args, **kwargs): + return stages.CollectionGroup(*args, **kwargs) + + def test_repr(self): + input_arg = "test" + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == "CollectionGroup(collection_id='test')" + + def test_to_pb(self): + input_arg = "test" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection_group" + assert len(result.args) == 2 + assert result.args[0].reference_value == "" + assert result.args[1].string_value == "test" + assert len(result.options) == 0 + + +class TestDatabase: + def _make_one(self, *args, **kwargs): + return stages.Database(*args, **kwargs) + + def test_ctor(self): + instance = self._make_one() + assert instance.name == "database" + + def test_repr(self): + instance = self._make_one() + repr_str = repr(instance) + assert repr_str == "Database()" + + def test_to_pb(self): + instance = self._make_one() + result = instance._to_pb() + assert 
result.name == "database" + assert len(result.args) == 0 + assert len(result.options) == 0 + + +class TestDistinct: + def _make_one(self, *args, **kwargs): + return stages.Distinct(*args, **kwargs) + + def test_ctor(self): + field1 = Field.of("field1") + instance = self._make_one("field2", field1) + assert len(instance.fields) == 2 + assert isinstance(instance.fields[0], Field) + assert instance.fields[0].path == "field2" + assert instance.fields[1] == field1 + assert instance.name == "distinct" + + def test_repr(self): + instance = self._make_one("field1", Field.of("field2")) + repr_str = repr(instance) + assert repr_str == "Distinct(fields=[Field.of('field1'), Field.of('field2')])" + + def test_to_pb(self): + instance = self._make_one("field1", Field.of("field2")) + result = instance._to_pb() + assert result.name == "distinct" + assert len(result.args) == 1 + expected_map_value = { + "fields": { + "field1": Value(field_reference_value="field1"), + "field2": Value(field_reference_value="field2"), + } + } + assert result.args[0].map_value.fields == expected_map_value["fields"] + assert len(result.options) == 0 + + +class TestDocuments: + def _make_one(self, *args, **kwargs): + return stages.Documents(*args, **kwargs) + + def test_ctor(self): + instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2") + assert instance.paths == ("/projects/p/databases/d/documents/c/doc1", "/c/doc2") + assert instance.name == "documents" + + def test_of(self): + mock_doc_ref1 = mock.Mock() + mock_doc_ref1.path = "projects/p/databases/d/documents/c/doc1" + mock_doc_ref2 = mock.Mock() + mock_doc_ref2.path = "c/doc2" # Test relative path as well + instance = stages.Documents.of(mock_doc_ref1, mock_doc_ref2) + assert instance.paths == ( + "/projects/p/databases/d/documents/c/doc1", + "/c/doc2", + ) + + def test_repr(self): + instance = self._make_one("/a/b", "/c/d") + repr_str = repr(instance) + assert repr_str == "Documents('/a/b', '/c/d')" + + def test_to_pb(self): + instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2") + result = instance._to_pb() + assert result.name == "documents" + assert len(result.args) == 2 + assert ( + result.args[0].reference_value == "/projects/p/databases/d/documents/c/doc1" + ) + assert result.args[1].reference_value == "/c/doc2" + assert len(result.options) == 0 + + +class TestFindNearest: + class TestFindNearestOptions: + def _make_one_options(self, *args, **kwargs): + return stages.FindNearestOptions(*args, **kwargs) + + def test_ctor_options(self): + limit_val = 10 + distance_field_val = Field.of("dist") + instance = self._make_one_options( + limit=limit_val, distance_field=distance_field_val + ) + assert instance.limit == limit_val + assert instance.distance_field == distance_field_val + + def test_ctor_defaults(self): + instance_default = self._make_one_options() + assert instance_default.limit is None + assert instance_default.distance_field is None + + def test_repr(self): + instance_empty = self._make_one_options() + assert repr(instance_empty) == "FindNearestOptions()" + instance_limit = self._make_one_options(limit=5) + assert repr(instance_limit) == "FindNearestOptions(limit=5)" + instance_distance = self._make_one_options(distance_field=Field.of("dist")) + assert ( + repr(instance_distance) + == "FindNearestOptions(distance_field=Field.of('dist'))" + ) + instance_full = self._make_one_options( + limit=5, distance_field=Field.of("dist") + ) + assert ( + repr(instance_full) + == "FindNearestOptions(limit=5, 
distance_field=Field.of('dist'))" + ) + + def _make_one(self, *args, **kwargs): + return stages.FindNearest(*args, **kwargs) + + def test_ctor_w_str_field(self): + field_path = "embedding_field" + vector_val = Vector([1.0, 2.0, 3.0]) + distance_measure_val = DistanceMeasure.EUCLIDEAN + options_val = stages.FindNearestOptions( + limit=5, distance_field=Field.of("distance") + ) + + instance_str_field = self._make_one( + field_path, vector_val, distance_measure_val, options=options_val + ) + assert isinstance(instance_str_field.field, Field) + assert instance_str_field.field.path == field_path + assert instance_str_field.vector == vector_val + assert instance_str_field.distance_measure == distance_measure_val + assert instance_str_field.options == options_val + assert instance_str_field.name == "find_nearest" + + def test_ctor_w_field_obj(self): + field_path = "embedding_field" + field_obj = Field.of(field_path) + vector_val = Vector([1.0, 2.0, 3.0]) + distance_measure_val = DistanceMeasure.EUCLIDEAN + instance_field_obj = self._make_one(field_obj, vector_val, distance_measure_val) + assert instance_field_obj.field == field_obj + assert instance_field_obj.options.limit is None # Default options + assert instance_field_obj.options.distance_field is None + + def test_ctor_w_vector_list(self): + field_path = "embedding_field" + distance_measure_val = DistanceMeasure.EUCLIDEAN + + vector_list = [4.0, 5.0] + instance_list_vector = self._make_one( + field_path, vector_list, distance_measure_val + ) + assert isinstance(instance_list_vector.vector, Vector) + assert instance_list_vector.vector == Vector(vector_list) + + def test_repr(self): + field_path = "embedding_field" + vector_val = Vector([1.0, 2.0]) + distance_measure_val = DistanceMeasure.EUCLIDEAN + options_val = stages.FindNearestOptions(limit=5) + instance = self._make_one( + field_path, vector_val, distance_measure_val, options=options_val + ) + repr_str = repr(instance) + expected_repr = "FindNearest(field=Field.of('embedding_field'), vector=Vector<1.0, 2.0>, distance_measure=, options=FindNearestOptions(limit=5))" + assert repr_str == expected_repr + + @pytest.mark.parametrize( + "distance_measure_val, expected_str", + [ + (DistanceMeasure.COSINE, "cosine"), + (DistanceMeasure.DOT_PRODUCT, "dot_product"), + (DistanceMeasure.EUCLIDEAN, "euclidean"), + ], + ) + def test_to_pb(self, distance_measure_val, expected_str): + field_path = "embedding" + vector_val = Vector([0.1, 0.2]) + options_val = stages.FindNearestOptions( + limit=7, distance_field=Field.of("dist_val") + ) + instance = self._make_one( + field_path, vector_val, distance_measure_val, options=options_val + ) + + result = instance._to_pb() + assert result.name == "find_nearest" + assert len(result.args) == 3 + # test field arg + assert result.args[0].field_reference_value == field_path + # test for vector arg + assert result.args[1].map_value.fields["__type__"].string_value == "__vector__" + assert ( + result.args[1].map_value.fields["value"].array_value.values[0].double_value + == 0.1 + ) + assert ( + result.args[1].map_value.fields["value"].array_value.values[1].double_value + == 0.2 + ) + # test for distance measure arg + assert result.args[2].string_value == expected_str + # test options + assert len(result.options) == 2 + assert result.options["limit"].integer_value == 7 + assert result.options["distance_field"].field_reference_value == "dist_val" + + def test_to_pb_no_options(self): + instance = self._make_one("emb", [1.0], DistanceMeasure.DOT_PRODUCT) + result = 
instance._to_pb() + assert len(result.options) == 0 + assert len(result.args) == 3 + + +class TestRawStage: + def _make_one(self, *args, **kwargs): + return stages.RawStage(*args, **kwargs) + + @pytest.mark.parametrize( + "input_args,expected_params", + [ + (("name",), []), + (("custom", Value(string_value="val")), [Value(string_value="val")]), + (("n", Value(integer_value=1)), [Value(integer_value=1)]), + (("n", Constant.of(1)), [Value(integer_value=1)]), + ( + ("n", Constant.of(True), Constant.of(False)), + [Value(boolean_value=True), Value(boolean_value=False)], + ), + ( + ("n", Constant.of(GeoPoint(1, 2))), + [Value(geo_point_value={"latitude": 1, "longitude": 2})], + ), + (("n", Constant.of(None)), [Value(null_value=0)]), + ( + ("n", Constant.of([0, 1, 2])), + [ + Value( + array_value={ + "values": [Value(integer_value=n) for n in range(3)] + } + ) + ], + ), + ( + ("n", Value(reference_value="/projects/p/databases/d/documents/doc")), + [Value(reference_value="/projects/p/databases/d/documents/doc")], + ), + ( + ("n", Constant.of({"a": "b"})), + [Value(map_value={"fields": {"a": Value(string_value="b")}})], + ), + ], + ) + def test_ctor_with_params(self, input_args, expected_params): + instance = self._make_one(*input_args) + assert instance.params == expected_params + + def test_ctor_with_options(self): + options = {"index_field": Field.of("index")} + field = Field.of("field") + alias = Field.of("alias") + standard_unnest = stages.Unnest( + field, alias, options=stages.UnnestOptions(**options) + ) + generic_unnest = stages.RawStage("unnest", field, alias, options=options) + assert standard_unnest._pb_args() == generic_unnest._pb_args() + assert standard_unnest._pb_options() == generic_unnest._pb_options() + assert standard_unnest._to_pb() == generic_unnest._to_pb() + + @pytest.mark.parametrize( + "input_args,expected", + [ + (("name",), "RawStage(name='name')"), + (("custom", Value(string_value="val")), "RawStage(name='custom')"), + ], + ) + def test_repr(self, input_args, expected): + instance = self._make_one(*input_args) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + instance = self._make_one("name", Constant.of(True), Constant.of("test")) + result = instance._to_pb() + assert result.name == "name" + assert len(result.args) == 2 + assert result.args[0].boolean_value is True + assert result.args[1].string_value == "test" + assert len(result.options) == 0 + + +class TestLimit: + def _make_one(self, *args, **kwargs): + return stages.Limit(*args, **kwargs) + + def test_repr(self): + instance = self._make_one(10) + repr_str = repr(instance) + assert repr_str == "Limit(limit=10)" + + def test_to_pb(self): + instance = self._make_one(5) + result = instance._to_pb() + assert result.name == "limit" + assert len(result.args) == 1 + assert result.args[0].integer_value == 5 + assert len(result.options) == 0 + + +class TestOffset: + def _make_one(self, *args, **kwargs): + return stages.Offset(*args, **kwargs) + + def test_repr(self): + instance = self._make_one(20) + repr_str = repr(instance) + assert repr_str == "Offset(offset=20)" + + def test_to_pb(self): + instance = self._make_one(3) + result = instance._to_pb() + assert result.name == "offset" + assert len(result.args) == 1 + assert result.args[0].integer_value == 3 + assert len(result.options) == 0 + + +class TestRemoveFields: + def _make_one(self, *args, **kwargs): + return stages.RemoveFields(*args, **kwargs) + + def test_ctor(self): + field1 = Field.of("field1") + instance = 
self._make_one("field2", field1) + assert len(instance.fields) == 2 + assert isinstance(instance.fields[0], Field) + assert instance.fields[0].path == "field2" + assert instance.fields[1] == field1 + assert instance.name == "remove_fields" + + def test_repr(self): + instance = self._make_one("field1", Field.of("field2")) + repr_str = repr(instance) + assert repr_str == "RemoveFields(Field.of('field1'), Field.of('field2'))" + + def test_to_pb(self): + instance = self._make_one("field1", Field.of("field2")) + result = instance._to_pb() + assert result.name == "remove_fields" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "field1" + assert result.args[1].field_reference_value == "field2" + assert len(result.options) == 0 + + +class TestReplaceWith: + def _make_one(self, *args, **kwargs): + return stages.ReplaceWith(*args, **kwargs) + + @pytest.mark.parametrize( + "in_field,expected_field", + [ + ("test", Field.of("test")), + ("test", Field.of("test")), + ("test", Field.of("test")), + (Field.of("test"), Field.of("test")), + (Field.of("test"), Field.of("test")), + ], + ) + def test_ctor(self, in_field, expected_field): + instance = self._make_one(in_field) + assert instance.field == expected_field + assert instance.name == "replace_with" + + def test_repr(self): + instance = self._make_one("test") + repr_str = repr(instance) + assert repr_str == "ReplaceWith(field=Field.of('test'))" + + def test_to_pb(self): + instance = self._make_one(Field.of("test")) + result = instance._to_pb() + assert result.name == "replace_with" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "test" + assert result.args[1].string_value == "full_replace" + + +class TestSample: + class TestSampleOptions: + def test_ctor_percent(self): + instance = stages.SampleOptions(0.25, stages.SampleOptions.Mode.PERCENT) + assert instance.value == 0.25 + assert instance.mode == stages.SampleOptions.Mode.PERCENT + + def test_ctor_documents(self): + instance = stages.SampleOptions(10, stages.SampleOptions.Mode.DOCUMENTS) + assert instance.value == 10 + assert instance.mode == stages.SampleOptions.Mode.DOCUMENTS + + def test_percentage(self): + instance = stages.SampleOptions.percentage(1) + assert instance.value == 1 + assert instance.mode == stages.SampleOptions.Mode.PERCENT + + def test_doc_limit(self): + instance = stages.SampleOptions.doc_limit(2) + assert instance.value == 2 + assert instance.mode == stages.SampleOptions.Mode.DOCUMENTS + + def test_repr_percentage(self): + instance = stages.SampleOptions.percentage(0.5) + assert repr(instance) == "SampleOptions.percentage(0.5)" + + def test_repr_documents(self): + instance = stages.SampleOptions.doc_limit(10) + assert repr(instance) == "SampleOptions.doc_limit(10)" + + def _make_one(self, *args, **kwargs): + return stages.Sample(*args, **kwargs) + + def test_ctor_w_int(self): + instance_int = self._make_one(10) + assert isinstance(instance_int.options, stages.SampleOptions) + assert instance_int.options.value == 10 + assert instance_int.options.mode == stages.SampleOptions.Mode.DOCUMENTS + assert instance_int.name == "sample" + + def test_ctor_w_options(self): + options = stages.SampleOptions.percentage(0.5) + instance_options = self._make_one(options) + assert instance_options.options == options + assert instance_options.name == "sample" + + def test_repr(self): + instance_int = self._make_one(10) + repr_str_int = repr(instance_int) + assert repr_str_int == "Sample(options=SampleOptions.doc_limit(10))" + + options = 
stages.SampleOptions.percentage(0.5) + instance_options = self._make_one(options) + repr_str_options = repr(instance_options) + assert repr_str_options == "Sample(options=SampleOptions.percentage(0.5))" + + def test_to_pb_documents_mode(self): + instance_docs = self._make_one(10) + result_docs = instance_docs._to_pb() + assert result_docs.name == "sample" + assert len(result_docs.args) == 2 + assert result_docs.args[0].integer_value == 10 + assert result_docs.args[1].string_value == "documents" + assert len(result_docs.options) == 0 + + def test_to_pb_percent_mode(self): + options_percent = stages.SampleOptions.percentage(0.25) + instance_percent = self._make_one(options_percent) + result_percent = instance_percent._to_pb() + assert result_percent.name == "sample" + assert len(result_percent.args) == 2 + assert result_percent.args[0].double_value == 0.25 + assert result_percent.args[1].string_value == "percent" + assert len(result_percent.options) == 0 + + +class TestSelect: + def _make_one(self, *args, **kwargs): + return stages.Select(*args, **kwargs) + + def test_repr(self): + instance = self._make_one("field1", Field.of("field2")) + repr_str = repr(instance) + assert ( + repr_str == "Select(projections=[Field.of('field1'), Field.of('field2')])" + ) + + def test_to_pb(self): + instance = self._make_one("field1", "field2.subfield", Field.of("field3")) + result = instance._to_pb() + assert result.name == "select" + assert len(result.args) == 1 + got_map = result.args[0].map_value.fields + assert got_map.get("field1").field_reference_value == "field1" + assert got_map.get("field2.subfield").field_reference_value == "field2.subfield" + assert got_map.get("field3").field_reference_value == "field3" + assert len(result.options) == 0 + + +class TestSort: + def _make_one(self, *args, **kwargs): + return stages.Sort(*args, **kwargs) + + def test_repr(self): + order1 = Ordering(Field.of("field1"), "ASCENDING") + instance = self._make_one(order1) + repr_str = repr(instance) + assert repr_str == "Sort(orders=[Field.of('field1').ascending()])" + + def test_to_pb(self): + order1 = Ordering(Field.of("name"), "ASCENDING") + order2 = Ordering(Field.of("age"), "DESCENDING") + instance = self._make_one(order1, order2) + result = instance._to_pb() + assert result.name == "sort" + assert len(result.args) == 2 + got_map = result.args[0].map_value.fields + assert got_map.get("expression").field_reference_value == "name" + assert got_map.get("direction").string_value == "ascending" + assert len(result.options) == 0 + + +class TestUnion: + def _make_one(self, *args, **kwargs): + return stages.Union(*args, **kwargs) + + def test_ctor(self): + mock_pipeline = mock.Mock(spec=_BasePipeline) + instance = self._make_one(mock_pipeline) + assert instance.other == mock_pipeline + assert instance.name == "union" + + def test_repr(self): + test_pipeline = _BasePipeline(mock.Mock()).sample(5) + instance = self._make_one(test_pipeline) + repr_str = repr(instance) + assert repr_str == f"Union(other={test_pipeline!r})" + + def test_to_pb(self): + test_pipeline = _BasePipeline(mock.Mock()).sample(5) + + instance = self._make_one(test_pipeline) + result = instance._to_pb() + + assert result.name == "union" + assert len(result.args) == 1 + assert result.args[0].pipeline_value == test_pipeline._to_pb().pipeline + assert len(result.options) == 0 + + +class TestUnnest: + class TestUnnestOptions: + def _make_one_options(self, *args, **kwargs): + return stages.UnnestOptions(*args, **kwargs) + + def test_ctor_options(self): + 
index_field_val = "my_index" + instance = self._make_one_options(index_field=index_field_val) + assert isinstance(instance.index_field, Field) + assert instance.index_field.path == index_field_val + + def test_repr(self): + instance = self._make_one_options(index_field="my_idx") + repr_str = repr(instance) + assert repr_str == "UnnestOptions(index_field='my_idx')" + + def _make_one(self, *args, **kwargs): + return stages.Unnest(*args, **kwargs) + + def test_ctor(self): + instance = self._make_one("my_field") + assert isinstance(instance.field, Field) + assert instance.field.path == "my_field" + assert isinstance(instance.alias, Field) + assert instance.alias.path == "my_field" + assert instance.options is None + assert instance.name == "unnest" + + def test_ctor_full(self): + """constructor with alias and options set""" + field = Field.of("items") + alias = Field.of("alias") + options = stages.UnnestOptions(index_field="item_index") + instance = self._make_one(field, alias, options=options) + assert isinstance(field, Field) + assert instance.field == field + assert isinstance(alias, Field) + assert instance.alias == alias + assert instance.options == options + assert instance.name == "unnest" + + def test_repr(self): + instance_simple = self._make_one("my_field") + repr_str_simple = repr(instance_simple) + assert ( + repr_str_simple + == "Unnest(field=Field.of('my_field'), alias=Field.of('my_field'), options=None)" + ) + + options = stages.UnnestOptions(index_field="item_idx") + instance_full = self._make_one( + Field.of("items"), Field.of("alias"), options=options + ) + repr_str_full = repr(instance_full) + assert ( + repr_str_full + == "Unnest(field=Field.of('items'), alias=Field.of('alias'), options=UnnestOptions(index_field='item_idx'))" + ) + + def test_to_pb(self): + instance = self._make_one(Field.of("dataPoints")) + result = instance._to_pb() + assert result.name == "unnest" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "dataPoints" + assert result.args[1].field_reference_value == "dataPoints" + assert len(result.options) == 0 + + def test_to_pb_full(self): + field_str = "items" + alias_str = "single_item" + options_val = stages.UnnestOptions(index_field="item_index") + instance = self._make_one(field_str, alias_str, options=options_val) + + result = instance._to_pb() + assert result.name == "unnest" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == field_str + assert result.args[1].field_reference_value == alias_str + + assert len(result.options) == 1 + assert result.options["index_field"].field_reference_value == "item_index" + + +class TestWhere: + def _make_one(self, *args, **kwargs): + return stages.Where(*args, **kwargs) + + def test_repr(self): + condition = Field.of("age").greater_than(30) + instance = self._make_one(condition) + repr_str = repr(instance) + assert ( + repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" + ) + + def test_to_pb(self): + condition = Field.of("city").equal("SF") + instance = self._make_one(condition) + result = instance._to_pb() + assert result.name == "where" + assert len(result.args) == 1 + got_fn = result.args[0].function_value + assert got_fn.name == "equal" + assert len(got_fn.args) == 2 + assert got_fn.args[0].field_reference_value == "city" + assert got_fn.args[1].string_value == "SF" + assert len(result.options) == 0 diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index b8c37cf84..7eaeef61b 100644 --- 
a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -1046,3 +1046,22 @@ def test_collection_group_get_partitions_w_offset(database): query = _make_collection_group(parent).offset(10) with pytest.raises(ValueError): list(query.get_partitions(2)) + + +def test_asyncquery_collection_pipeline_type(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = make_client() + parent = client.collection("test") + query = parent._query() + ppl = query._build_pipeline(client.pipeline()) + assert isinstance(ppl, Pipeline) + + +def test_asyncquery_collectiongroup_pipeline_type(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = make_client() + query = client.collection_group("test") + ppl = query._build_pipeline(client.pipeline()) + assert isinstance(ppl, Pipeline) diff --git a/tests/unit/v1/test_query_profile.py b/tests/unit/v1/test_query_profile.py index a3b0390c6..5b1e470b8 100644 --- a/tests/unit/v1/test_query_profile.py +++ b/tests/unit/v1/test_query_profile.py @@ -124,3 +124,64 @@ def test_explain_options__to_dict(): assert ExplainOptions(analyze=True)._to_dict() == {"analyze": True} assert ExplainOptions(analyze=False)._to_dict() == {"analyze": False} + + +@pytest.mark.parametrize("mode_str", ["analyze", "explain"]) +def test_pipeline_explain_options__to_value(mode_str): + """ + Should be able to create a Value protobuf representation of ExplainOptions + """ + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions + from google.cloud.firestore_v1.types.document import MapValue + from google.cloud.firestore_v1.types.document import Value + + options = PipelineExplainOptions(mode=mode_str) + expected_value = Value( + map_value=MapValue(fields={"mode": Value(string_value=mode_str)}) + ) + assert options._to_value() == expected_value + + +def test_explain_stats_get_raw(): + """ + Test ExplainStats.get_raw(). Should return input directly + """ + from google.cloud.firestore_v1.query_profile import ExplainStats + + input = object() + stats = ExplainStats(input) + assert stats.get_raw() is input + + +def test_explain_stats_get_text(): + """ + Test ExplainStats.get_text() + """ + from google.cloud.firestore_v1.query_profile import ExplainStats + from google.cloud.firestore_v1.types import explain_stats as explain_stats_pb2 + from google.protobuf import any_pb2 + from google.protobuf import wrappers_pb2 + + expected_text = "some text" + text_pb = any_pb2.Any() + text_pb.Pack(wrappers_pb2.StringValue(value=expected_text)) + expected_stats_pb = explain_stats_pb2.ExplainStats(data=text_pb) + stats = ExplainStats(expected_stats_pb) + assert stats.get_text() == expected_text + + +def test_explain_stats_get_text_error(): + """ + Test ExplainStats.get_text() raises QueryExplainError + """ + from google.cloud.firestore_v1.query_profile import ( + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types import explain_stats as explain_stats_pb2 + + expected_stats_pb = explain_stats_pb2.ExplainStats(data={}) + stats = ExplainStats(expected_stats_pb) + with pytest.raises(QueryExplainError) as exc: + stats.get_text() + assert "Unable to decode explain stats" in str(exc.value)
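The profiling tests above exercise PipelineExplainOptions and ExplainStats end to end. As a hedged illustration of how a caller might combine those pieces, assuming the preview pipeline API and that the synchronous stream() accepts explain_options like its async counterpart (the "books" collection is a placeholder):

from google.cloud import firestore
from google.cloud.firestore_v1.query_profile import PipelineExplainOptions

client = firestore.Client()
stream = client.pipeline().collection("books").stream(
    explain_options=PipelineExplainOptions(mode="analyze")
)
results = list(stream)  # explain stats are only populated once the stream is consumed
print(stream.explain_stats.get_text())  # raises QueryExplainError if the stats cannot be decoded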