diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py
index 6f392207e..521b4fb41 100644
--- a/google/cloud/firestore_v1/base_aggregation.py
+++ b/google/cloud/firestore_v1/base_aggregation.py
@@ -361,12 +361,8 @@ def _build_pipeline(self, source: "PipelineSource"):
         """
         Convert this query into a Pipeline
 
-        Queries containing a `cursor` or `limit_to_last` are not currently supported
-
         Args:
             source: the PipelineSource to build the pipeline off of
-        Raises:
-            - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
         Returns:
             a Pipeline representing the query
         """
diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py
index 070e54cc4..25d07aec9 100644
--- a/google/cloud/firestore_v1/base_collection.py
+++ b/google/cloud/firestore_v1/base_collection.py
@@ -608,12 +608,8 @@ def _build_pipeline(self, source: "PipelineSource"):
         """
         Convert this query into a Pipeline
 
-        Queries containing a `cursor` or `limit_to_last` are not currently supported
-
         Args:
             source: the PipelineSource to build the pipeline off of
-        Raises:
-            - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
         Returns:
             a Pipeline representing the query
         """
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py
index b1b74fcf1..30e1c8fb7 100644
--- a/google/cloud/firestore_v1/base_query.py
+++ b/google/cloud/firestore_v1/base_query.py
@@ -1134,12 +1134,8 @@ def _build_pipeline(self, source: "PipelineSource"):
         """
         Convert this query into a Pipeline
 
-        Queries containing a `cursor` or `limit_to_last` are not currently supported
-
         Args:
             source: the PipelineSource to build the pipeline off of
-        Raises:
-            - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
         Returns:
             a Pipeline representing the query
         """
@@ -1162,9 +1158,10 @@ def _build_pipeline(self, source: "PipelineSource"):
 
         # Orders
         orders = self._normalize_orders()
+
+        exists = []
+        orderings = []
         if orders:
-            exists = []
-            orderings = []
             for order in orders:
                 field = pipeline_expressions.Field.of(order.field.field_path)
                 exists.append(field.exists())
@@ -1178,23 +1175,59 @@ def _build_pipeline(self, source: "PipelineSource"):
             # Add exists filters to match Query's implicit orderby semantics.
             if len(exists) == 1:
                 ppl = ppl.where(exists[0])
-            else:
+            elif len(exists) > 1:
                 ppl = ppl.where(pipeline_expressions.And(*exists))
 
-        # Add sort orderings
-        ppl = ppl.sort(*orderings)
+        if orderings:
+            # Normalize cursors to get the raw values corresponding to the orders
+            start_at_val = None
+            if self._start_at:
+                start_at_val = self._normalize_cursor(self._start_at, orders)
+
+            end_at_val = None
+            if self._end_at:
+                end_at_val = self._normalize_cursor(self._end_at, orders)
+
+            # If limit_to_last is set, we need to reverse the orderings to find the
+            # "last" N documents (which effectively become the "first" N in reverse order).
+            if self._limit_to_last:
+                actual_orderings = _reverse_orderings(orderings)
+                ppl = ppl.sort(*actual_orderings)
+
+            # Apply cursor conditions.
+            # Cursors are translated into filter conditions (e.g., field > value)
+            # based on the orderings.
+            if start_at_val:
+                ppl = ppl.where(
+                    _where_conditions_from_cursor(
+                        start_at_val, orderings, is_start_cursor=True
+                    )
+                )
 
-        # Cursors, Limit and Offset
-        if self._start_at or self._end_at or self._limit_to_last:
-            raise NotImplementedError(
-                "Query to Pipeline conversion: cursors and limit_to_last is not supported yet."
-            )
-        else:  # Limit & Offset without cursors
-            if self._offset:
-                ppl = ppl.offset(self._offset)
-            if self._limit:
+            if end_at_val:
+                ppl = ppl.where(
+                    _where_conditions_from_cursor(
+                        end_at_val, orderings, is_start_cursor=False
+                    )
+                )
+
+            if not self._limit_to_last:
+                ppl = ppl.sort(*orderings)
+
+            if self._limit is not None:
                 ppl = ppl.limit(self._limit)
 
+            # If we reversed the orderings for limit_to_last, we must now re-sort
+            # using the original orderings to return the results in the user-requested order.
+            if self._limit_to_last:
+                ppl = ppl.sort(*orderings)
+        elif self._limit is not None and not self._limit_to_last:
+            ppl = ppl.limit(self._limit)
+
+        # Offset
+        if self._offset:
+            ppl = ppl.offset(self._offset)
+
         return ppl
 
     def _comparator(self, doc1, doc2) -> int:
@@ -1366,6 +1399,69 @@ def _cursor_pb(cursor_pair: Optional[Tuple[list, bool]]) -> Optional[Cursor]:
     return None
 
 
+def _where_conditions_from_cursor(
+    cursor: Tuple[List, bool],
+    orderings: List[pipeline_expressions.Ordering],
+    is_start_cursor: bool,
+) -> pipeline_expressions.BooleanExpression:
+    """
+    Converts a cursor into a filter condition for the pipeline.
+
+    Args:
+        cursor: The cursor values and the 'before' flag.
+        orderings: The list of ordering expressions used in the query.
+        is_start_cursor: True if this is a start_at/start_after cursor, False if it is an end_at/end_before cursor.
+    Returns:
+        A BooleanExpression representing the cursor condition.
+    """
+    cursor_values, before = cursor
+    size = len(cursor_values)
+
+    if is_start_cursor:
+        filter_func = pipeline_expressions.Expression.greater_than
+    else:
+        filter_func = pipeline_expressions.Expression.less_than
+
+    field = orderings[size - 1].expr
+    value = pipeline_expressions.Constant(cursor_values[size - 1])
+
+    # Add condition for the last bound
+    condition = filter_func(field, value)
+
+    if (is_start_cursor and before) or (not is_start_cursor and not before):
+        # When the cursor bound is inclusive, the last bound may also
+        # be equal to the cursor value; otherwise the inequality is strict.
+        condition = pipeline_expressions.Or(condition, field.equal(value))
+
+    # Iterate backwards over the remaining bounds, adding a condition for each one
+    for i in range(size - 2, -1, -1):
+        field = orderings[i].expr
+        value = pipeline_expressions.Constant(cursor_values[i])
+
+        # For each field in the orderings, the condition is either
+        # a) lessThan|greaterThan the cursor value, or
+        # b) equal to the cursor value, combined with the condition built for the later fields
+        condition = pipeline_expressions.Or(
+            filter_func(field, value),
+            pipeline_expressions.And(field.equal(value), condition),
+        )
+
+    return condition
+
+
+def _reverse_orderings(
+    orderings: List[pipeline_expressions.Ordering],
+) -> List[pipeline_expressions.Ordering]:
+    reversed_orderings = []
+    for o in orderings:
+        if o.order_dir == pipeline_expressions.Ordering.Direction.ASCENDING:
+            new_dir = "descending"
+        else:
+            new_dir = "ascending"
+        reversed_orderings.append(pipeline_expressions.Ordering(o.expr, new_dir))
+    return reversed_orderings
+
+
 def _query_response_to_snapshot(
     response_pb: RunQueryResponse, collection, expected_prefix: str
 ) -> Optional[document.DocumentSnapshot]:
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 615ff1226..bcdb8afe3 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -1445,7 +1445,7 @@ def test_query_stream_w_field_path(query_docs, database):
     verify_pipeline(query)
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 def test_query_stream_w_start_end_cursor(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1459,6 +1459,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database):
     for key, value in values:
         assert stored[key] == value
         assert value["a"] == num_vals - 2
+    verify_pipeline(query)
 
 
 @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
@@ -1869,7 +1870,7 @@ def test_pipeline_w_read_time(query_docs, cleanup, database):
         assert key != new_ref.id
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 def test_query_with_order_dot_key(client, cleanup, database):
     db = client
     collection_id = "collek" + UNIQUE_RESOURCE_ID
@@ -1906,6 +1907,9 @@ def test_query_with_order_dot_key(client, cleanup, database):
     )
     cursor_with_key_data = list(query4.stream())
     assert found_data == [snap.to_dict() for snap in cursor_with_key_data]
+    verify_pipeline(query)
+    verify_pipeline(query2)
+    verify_pipeline(query3)
 
 
 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1999,7 +2003,7 @@ def test_collection_group_queries(client, cleanup, database):
     verify_pipeline(query)
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 def test_collection_group_queries_startat_endat(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
 
@@ -2030,6 +2034,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
     snapshots = list(query.stream())
     found = set(snapshot.id for snapshot in snapshots)
     assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])
+    verify_pipeline(query)
 
     query = (
         client.collection_group(collection_group)
@@ -2040,6 +2045,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
     snapshots = list(query.stream())
     found = set(snapshot.id for snapshot in snapshots)
     assert found == set(["cg-doc2"])
+    verify_pipeline(query)
 
 
 @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
@@ -2860,6 +2866,7 @@ def test_repro_429(client, cleanup, database):
     for snapshot in query2.stream():
         print(f"id: {snapshot.id}")
     verify_pipeline(query)
+    verify_pipeline(query2)
 
 
 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -3019,7 +3026,7 @@ def test_count_query_stream_empty_aggregation(query, database):
     assert "Aggregations can not be empty" in exc_info.value.message
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 def test_count_query_with_start_at(query, database):
     """
     Ensure that count aggregation queries work when chained with a start_at
@@ -3036,6 +3043,7 @@ def test_count_query_with_start_at(query, database):
     for result in count_query.stream():
         for aggregation_result in result:
             assert aggregation_result.value == expected_count
+    verify_pipeline(count_query)
 
 
 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py
index 4a4dac727..13756f972 100644
--- a/tests/unit/v1/test_base_query.py
+++ b/tests/unit/v1/test_base_query.py
@@ -2116,19 +2116,96 @@ def test__query_pipeline_order_sorts():
     assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING
 
 
-def test__query_pipeline_unsupported():
+def test__query_pipeline_cursors():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
     client = make_client()
-    query_start = client.collection("my_col").start_at({"field_a": "value"})
-    with pytest.raises(NotImplementedError, match="cursors"):
-        query_start._build_pipeline(client.pipeline())
+    query_start = (
+        client.collection("my_col").order_by("field_a").start_at({"field_a": "value"})
+    )
+    pipeline = query_start._build_pipeline(client.pipeline())
+
+    # Expected stages: Collection, Exists, Where(Cursor), Sort
+    assert len(pipeline.stages) == 4
+    assert pipeline.stages[0].path == "/my_col"
+    assert isinstance(pipeline.stages[1], stages.Where)  # Exists
+
+    cursor_stage = pipeline.stages[2]
+    assert isinstance(cursor_stage, stages.Where)
+    condition = cursor_stage.condition
+    # start_at({"field_a": "value"}) -> field_a >= "value"
+    # Implemented as Or(field_a > "value", field_a == "value")
+    assert isinstance(condition, expr.Or)
+    assert len(condition.params) == 2
+    assert condition.params[0].name == "greater_than"
+    assert isinstance(condition.params[0].params[0], expr.Field)
+    assert condition.params[0].params[0].path == "field_a"
+    assert condition.params[1].name == "equal"
+
+    assert isinstance(pipeline.stages[3], stages.Sort)
+
+    query_end = (
+        client.collection("my_col").order_by("field_a").end_at({"field_a": "value"})
+    )
+    pipeline = query_end._build_pipeline(client.pipeline())
+
+    # Expected stages: Collection, Exists, Where(Cursor), Sort
+    assert len(pipeline.stages) == 4
+    assert pipeline.stages[0].path == "/my_col"
+    assert isinstance(pipeline.stages[1], stages.Where)  # Exists
+
+    cursor_stage = pipeline.stages[2]
+    assert isinstance(cursor_stage, stages.Where)
+    condition = cursor_stage.condition
+    # end_at({"field_a": "value"}) -> field_a <= "value"
+    # Implemented as Or(field_a < "value", field_a == "value")
+    assert isinstance(condition, expr.Or)
+    assert len(condition.params) == 2
+    assert condition.params[0].name == "less_than"
+    assert isinstance(condition.params[0].params[0], expr.Field)
+    assert condition.params[0].params[0].path == "field_a"
+    assert condition.params[1].name == "equal"
+
+    assert isinstance(pipeline.stages[3], stages.Sort)
+
+
+def test__query_pipeline_limit_to_last():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    client = make_client()
+    query_limit_last = (
+        client.collection("my_col").order_by("field_a").limit_to_last(10)
+    )
+    pipeline = query_limit_last._build_pipeline(client.pipeline())
+    # stages: collection, exists, sort(desc), limit, sort(asc)
+
+    assert len(pipeline.stages) == 5
 
-    query_end = client.collection("my_col").end_at({"field_a": "value"})
-    with pytest.raises(NotImplementedError, match="cursors"):
-        query_end._build_pipeline(client.pipeline())
+    # 0. Collection
+    assert pipeline.stages[0].path == "/my_col"
 
-    query_limit_last = client.collection("my_col").limit_to_last(10)
-    with pytest.raises(NotImplementedError, match="limit_to_last"):
-        query_limit_last._build_pipeline(client.pipeline())
+    # 1. Exists
+    exists_stage = pipeline.stages[1]
+    assert isinstance(exists_stage, stages.Where)
+
+    # 2. Sort DESCENDING (reversed)
+    sort_desc = pipeline.stages[2]
+    assert isinstance(sort_desc, stages.Sort)
+    assert len(sort_desc.orders) == 1
+    assert sort_desc.orders[0].expr.path == "field_a"
+    assert sort_desc.orders[0].order_dir == expr.Ordering.Direction.DESCENDING
+
+    # 3. Limit
+    limit_stage = pipeline.stages[3]
+    assert isinstance(limit_stage, stages.Limit)
+    assert limit_stage.limit == 10
+
+    # 4. Sort ASCENDING (original)
+    sort_asc = pipeline.stages[4]
+    assert isinstance(sort_asc, stages.Sort)
+    assert len(sort_asc.orders) == 1
+    assert sort_asc.orders[0].expr.path == "field_a"
+    assert sort_asc.orders[0].order_dir == expr.Ordering.Direction.ASCENDING
 
 
 def test__query_pipeline_limit():
@@ -2298,3 +2375,99 @@ def _make_snapshot(docref, values):
     from google.cloud.firestore_v1 import document
 
     return document.DocumentSnapshot(docref, values, True, None, None, None)
+
+
+def test__build_pipeline_limit_to_last_ordering():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    # Verify the stage order for limit_to_last=True:
+    # 1. Sort (reversed)
+    # 2. Where (cursor condition)
+
+    client = make_client()
+    # Query: Order by 'a' ASC, StartAt(10), LimitToLast(5)
+    query = (
+        client.collection("my_col").order_by("a").start_at({"a": 10}).limit_to_last(5)
+    )
+
+    pipeline = query._build_pipeline(client.pipeline())
+
+    # Expected stages:
+    # 0. Collection
+    # 1. Exists (for 'a')
+    # 2. Sort (DESCENDING) -> This must come BEFORE the cursor filter
+    # 3. Where (a > 10 condition or similar)
+    # 4. Limit (5)
+    # 5. Sort (ASCENDING)
+
+    assert len(pipeline.stages) >= 4
+
+    # Find indices
+    sort_reversed_idx = -1
+    cursor_where_idx = -1
+
+    for i, stage in enumerate(pipeline.stages):
+        if isinstance(stage, stages.Sort):
+            # Check if it is the reversed sort (DESCENDING)
+            if (
+                len(stage.orders) > 0
+                and stage.orders[0].order_dir == expr.Ordering.Direction.DESCENDING
+            ):
+                if sort_reversed_idx == -1:
+                    sort_reversed_idx = i
+
+        if isinstance(stage, stages.Where):
+            # Check if this is the cursor condition. The cursor condition for
+            # start_at({"a": 10}) should reference 'a' and 10, usually as an
+            # Or or a comparison.
+            # The Exists filter is also a Where, but it's usually `exists(a)`.
+
+            # Simple check: the condition is not just an 'exists' function call.
+            cond = stage.condition
+            if not (hasattr(cond, "name") and cond.name == "exists"):
+                # Assume this is the cursor filter
+                cursor_where_idx = i
+
+    assert sort_reversed_idx != -1, "Reversed sort stage not found"
+    assert cursor_where_idx != -1, "Cursor filter stage not found"
+
+    # Reversed Sort must happen BEFORE Cursor Filter
+    assert sort_reversed_idx < cursor_where_idx
+
+
+def test__build_pipeline_normal_ordering():
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    # Verify the stage order for limit_to_last=False (normal):
+    # 1. Where (cursor condition)
+    # 2. Sort
+
+    client = make_client()
+    # Query: Order by 'a' ASC, StartAt(10)
+    query = client.collection("my_col").order_by("a").start_at({"a": 10})
+
+    pipeline = query._build_pipeline(client.pipeline())
+
+    # Expected stages:
+    # 0. Collection
+    # 1. Exists (for 'a')
+    # 2. Where (cursor condition)
+    # 3. Sort (ASCENDING)
+
+    sort_idx = -1
+    cursor_where_idx = -1
+
+    for i, stage in enumerate(pipeline.stages):
+        if isinstance(stage, stages.Sort):
+            sort_idx = i
+
+        if isinstance(stage, stages.Where):
+            cond = stage.condition
+            if not (hasattr(cond, "name") and cond.name == "exists"):
+                cursor_where_idx = i
+
+    assert sort_idx != -1, "Sort stage not found"
+    assert cursor_where_idx != -1, "Cursor filter stage not found"
+
+    # Cursor Filter must happen BEFORE Sort
+    assert cursor_where_idx < sort_idx
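
To illustrate the recurrence implemented in `_where_conditions_from_cursor` above: the following minimal standalone sketch (not part of the patch) mirrors the right-to-left fold over the cursor bounds, using plain strings in place of `pipeline_expressions` objects. The function name and parameters here are hypothetical, and it assumes the same comparison convention as the patch (start cursors use greater-than, end cursors use less-than).

def cursor_condition(values, fields, is_start, before):
    # Comparison direction: start cursors lower-bound the result set,
    # end cursors upper-bound it.
    op = ">" if is_start else "<"
    # start_at (before=True) and end_at (before=False) are inclusive bounds.
    inclusive = (is_start and before) or (not is_start and not before)

    # Innermost condition: compare on the last ordered field.
    cond = f"{fields[-1]} {op} {values[-1]}"
    if inclusive:
        cond = f"({cond} OR {fields[-1]} == {values[-1]})"

    # Fold the remaining bounds from right to left, like the
    # range(size - 2, -1, -1) loop in the patch: earlier fields match
    # strictly, or match exactly and defer to the later fields' condition.
    for f, v in zip(reversed(fields[:-1]), reversed(values[:-1])):
        cond = f"({f} {op} {v} OR ({f} == {v} AND {cond}))"
    return cond

print(cursor_condition([1, 2], ["a", "b"], is_start=True, before=True))
# (a > 1 OR (a == 1 AND (b > 2 OR b == 2)))

For orderings `a ASC, b ASC`, this is the lexicographic "document position >= (1, 2)" semantics expected of an inclusive start cursor.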