diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 6f392207e..521b4fb41 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -361,12 +361,8 @@ def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline - Queries containing a `cursor` or `limit_to_last` are not currently supported - Args: source: the PipelineSource to build the pipeline off of - Raises: - - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 070e54cc4..25d07aec9 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -608,12 +608,8 @@ def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline - Queries containing a `cursor` or `limit_to_last` are not currently supported - Args: source: the PipelineSource to build the pipeline off o - Raises: - - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index b1b74fcf1..54a1f1618 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1134,12 +1134,8 @@ def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline - Queries containing a `cursor` or `limit_to_last` are not currently supported - Args: source: the PipelineSource to build the pipeline off of - Raises: - - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ @@ -1161,39 +1157,61 @@ def _build_pipeline(self, source: "PipelineSource"): ppl = ppl.select(*[field.field_path for field in self._projection.fields]) # Orders - orders = self._normalize_orders() - if orders: - exists = [] - orderings = [] - for order in orders: - field = pipeline_expressions.Field.of(order.field.field_path) - exists.append(field.exists()) - direction = ( - "ascending" - if order.direction == StructuredQuery.Direction.ASCENDING - else "descending" - ) - orderings.append(pipeline_expressions.Ordering(field, direction)) - # Add exists filters to match Query's implicit orderby semantics. - if len(exists) == 1: - ppl = ppl.where(exists[0]) - else: - ppl = ppl.where(pipeline_expressions.And(*exists)) + # "explicit_orders" are only those explicitly added by the user via order_by(). + # We only generate existence filters for these fields. + if self._orders: + exists = [ + pipeline_expressions.Field.of(o.field.field_path).exists() + for o in self._orders + ] + ppl = ppl.where( + pipeline_expressions.And(*exists) if len(exists) > 1 else exists[0] + ) + + # "normalized_orders" includes both user-specified orders and implicit orders + # (e.g. __name__ or inequality fields) required by Firestore semantics. + normalized_orders = self._normalize_orders() + orderings = [ + pipeline_expressions.Ordering( + pipeline_expressions.Field.of(o.field.field_path), + "ascending" + if o.direction == StructuredQuery.Direction.ASCENDING + else "descending", + ) + for o in normalized_orders + ] + + # Apply cursors as filters. + if orderings: + for cursor, is_start in [(self._start_at, True), (self._end_at, False)]: + cursor = self._normalize_cursor(cursor, normalized_orders) + if cursor: + ppl = ppl.where( + _where_conditions_from_cursor(cursor, orderings, is_start) + ) + + # Handle sort and limit, including limit_to_last semantics. + is_limit_to_last = self._limit_to_last and bool(orderings) - # Add sort orderings + if is_limit_to_last: + # If limit_to_last is set, we need to reverse the orderings to find the + # "last" N documents (which effectively become the "first" N in reverse order). + ppl = ppl.sort(*_reverse_orderings(orderings)) + elif orderings: ppl = ppl.sort(*orderings) - # Cursors, Limit and Offset - if self._start_at or self._end_at or self._limit_to_last: - raise NotImplementedError( - "Query to Pipeline conversion: cursors and limit_to_last is not supported yet." - ) - else: # Limit & Offset without cursors - if self._offset: - ppl = ppl.offset(self._offset) - if self._limit: - ppl = ppl.limit(self._limit) + if self._limit is not None and (not self._limit_to_last or orderings): + ppl = ppl.limit(self._limit) + + if is_limit_to_last: + # If we reversed the orderings for limit_to_last, we must now re-sort + # using the original orderings to return the results in the user-requested order. + ppl = ppl.sort(*orderings) + + # Offset + if self._offset: + ppl = ppl.offset(self._offset) return ppl @@ -1366,6 +1384,91 @@ def _cursor_pb(cursor_pair: Optional[Tuple[list, bool]]) -> Optional[Cursor]: return None +def _get_cursor_exclusive_condition( + is_start_cursor: bool, + ordering: pipeline_expressions.Ordering, + value: pipeline_expressions.Constant, +) -> pipeline_expressions.BooleanExpression: + """ + Helper to determine the correct comparison operator (greater_than or less_than) + based on the cursor type (start/end) and the sort direction (ascending/descending). + """ + field = ordering.expr + if ( + is_start_cursor + and ordering.order_dir == pipeline_expressions.Ordering.Direction.ASCENDING + ) or ( + not is_start_cursor + and ordering.order_dir == pipeline_expressions.Ordering.Direction.DESCENDING + ): + return field.greater_than(value) + else: + return field.less_than(value) + + +def _where_conditions_from_cursor( + cursor: Tuple[List, bool], + orderings: List[pipeline_expressions.Ordering], + is_start_cursor: bool, +) -> pipeline_expressions.BooleanExpression: + """ + Converts a cursor into a filter condition for the pipeline. + + Args: + cursor: The cursor values and the 'before' flag. + orderings: The list of ordering expressions used in the query. + is_start_cursor: True if this is a start_at/start_after cursor, False if it is an end_at/end_before cursor. + Returns: + A BooleanExpression representing the cursor condition. + """ + cursor_values, before = cursor + size = len(cursor_values) + + ordering = orderings[size - 1] + field = ordering.expr + value = pipeline_expressions.Constant(cursor_values[size - 1]) + + # Add condition for last bound + condition = _get_cursor_exclusive_condition(is_start_cursor, ordering, value) + + if (is_start_cursor and before) or (not is_start_cursor and not before): + # When the cursor bound is inclusive, then the last bound + # can be equal to the value, otherwise it's not equal + condition = pipeline_expressions.Or(condition, field.equal(value)) + + # Iterate backwards over the remaining bounds, adding a condition for each one + for i in range(size - 2, -1, -1): + ordering = orderings[i] + field = ordering.expr + value = pipeline_expressions.Constant(cursor_values[i]) + + # For each field in the orderings, the condition is either + # a) lessThan|greaterThan the cursor value, + # b) or equal the cursor value and lessThan|greaterThan the cursor values for other fields + exclusive_condition = _get_cursor_exclusive_condition( + is_start_cursor, ordering, value + ) + condition = pipeline_expressions.Or( + exclusive_condition, + pipeline_expressions.And(field.equal(value), condition), + ) + + return condition + + +def _reverse_orderings( + orderings: List[pipeline_expressions.Ordering], +) -> List[pipeline_expressions.Ordering]: + reversed_orderings = [] + for o in orderings: + if o.order_dir == pipeline_expressions.Ordering.Direction.ASCENDING: + new_dir = "descending" + else: + new_dir = "ascending" + reversed_orderings.append(pipeline_expressions.Ordering(o.expr, new_dir)) + return reversed_orderings + + def _query_response_to_snapshot( response_pb: RunQueryResponse, collection, expected_prefix: str ) -> Optional[document.DocumentSnapshot]: diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index c0ff3923a..b01dc340d 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1833,7 +1833,10 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: return And(field.exists(), field.equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: - return And(field.exists(), field.not_equal(value)) + # In Enterprise DBs NOT_EQUAL will match a field that does not exist, + # therefore we do not want an existence filter for the NOT_EQUAL conversion + # so the Query and Pipeline behavior are consistent in Enterprise. + return field.not_equal(value) if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: return And(field.exists(), field.array_contains(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: @@ -1841,7 +1844,10 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: return And(field.exists(), field.equal_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: - return And(field.exists(), field.not_equal_any(value)) + # In Enterprise DBs NOT_IN will match a field that does not exist, + # therefore we do not want an existence filter for the NOT_IN conversion + # so the Query and Pipeline behavior are consistent in Enterprise. + return field.not_equal_any(value) else: raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.Filter): diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index 8f3c0a626..9aded4f75 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -57,12 +57,8 @@ def create_from( """ Create a pipeline from an existing query - Queries containing a `cursor` or `limit_to_last` are not currently supported - Args: query: the query to build the pipeline off of - Raises: - - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a new pipeline instance representing the query """ diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 4a4dac727..a653d07d7 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2116,19 +2116,94 @@ def test__query_pipeline_order_sorts(): assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING -def test__query_pipeline_unsupported(): +def test__query_pipeline_cursors(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query_start = ( + client.collection("my_col").order_by("field_a").start_at({"field_a": 10}) + ) + pipeline = query_start._build_pipeline(client.pipeline()) + + # Stages: + # 0: Collection + # 1: Where (exists field_a) - Generated because field_a is explicitly ordered + # 2: Where (cursor condition) + # 3: Sort (field_a) + assert len(pipeline.stages) == 4 + + where_stage = pipeline.stages[2] + assert isinstance(where_stage, stages.Where) + # Expected: (field_a > 10) OR (field_a == 10) + assert isinstance(where_stage.condition, expr.Or) + params = where_stage.condition.params + assert len(params) == 2 + assert params[0].name == "greater_than" + assert params[1].name == "equal" + + +def test__query_pipeline_limit_to_last(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query = client.collection("my_col").order_by("field_a").limit_to_last(5) + pipeline = query._build_pipeline(client.pipeline()) + + # Stages: + # 0: Collection + # 1: Where (exists field_a) + # 2: Sort (field_a DESC) - Reversed + # 3: Limit (5) + # 4: Sort (field_a ASC) - Restored + assert len(pipeline.stages) == 5 + + sort_reversed = pipeline.stages[2] + assert isinstance(sort_reversed, stages.Sort) + assert sort_reversed.orders[0].order_dir == expr.Ordering.Direction.DESCENDING + + limit_stage = pipeline.stages[3] + assert isinstance(limit_stage, stages.Limit) + assert limit_stage.limit == 5 + + sort_restored = pipeline.stages[4] + assert isinstance(sort_restored, stages.Sort) + assert sort_restored.orders[0].order_dir == expr.Ordering.Direction.ASCENDING + + +def test__query_pipeline_limit_to_last_descending(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1.base_query import BaseQuery + client = make_client() - query_start = client.collection("my_col").start_at({"field_a": "value"}) - with pytest.raises(NotImplementedError, match="cursors"): - query_start._build_pipeline(client.pipeline()) + # User orders by field_a DESCENDING + query = ( + client.collection("my_col") + .order_by("field_a", direction=BaseQuery.DESCENDING) + .limit_to_last(5) + ) + pipeline = query._build_pipeline(client.pipeline()) - query_end = client.collection("my_col").end_at({"field_a": "value"}) - with pytest.raises(NotImplementedError, match="cursors"): - query_end._build_pipeline(client.pipeline()) + # Stages: + # 0: Collection + # 1: Where (exists field_a) + # 2: Sort (field_a ASCENDING) - Reversed from DESCENDING + # 3: Limit (5) + # 4: Sort (field_a DESCENDING) - Restored to original + assert len(pipeline.stages) == 5 - query_limit_last = client.collection("my_col").limit_to_last(10) - with pytest.raises(NotImplementedError, match="limit_to_last"): - query_limit_last._build_pipeline(client.pipeline()) + sort_reversed = pipeline.stages[2] + assert isinstance(sort_reversed, stages.Sort) + # Should be ASCENDING because original was DESCENDING + assert sort_reversed.orders[0].order_dir == expr.Ordering.Direction.ASCENDING + + limit_stage = pipeline.stages[3] + assert isinstance(limit_stage, stages.Limit) + assert limit_stage.limit == 5 + + sort_restored = pipeline.stages[4] + assert isinstance(sort_restored, stages.Sort) + # Should be DESCENDING (original) + assert sort_restored.orders[0].order_dir == expr.Ordering.Direction.DESCENDING def test__query_pipeline_limit(): @@ -2298,3 +2373,140 @@ def _make_snapshot(docref, values): from google.cloud.firestore_v1 import document return document.DocumentSnapshot(docref, values, True, None, None, None) + + +def test__where_conditions_from_cursor_descending(): + from google.cloud.firestore_v1.base_query import _where_conditions_from_cursor + from google.cloud.firestore_v1 import pipeline_expressions + + # Create ordering: field DESC + field_expr = pipeline_expressions.Field.of("field") + ordering = pipeline_expressions.Ordering(field_expr, "descending") + + # Case 1: StartAt (inclusive) -> <= 10 + cursor = ([10], True) + condition = _where_conditions_from_cursor(cursor, [ordering], is_start_cursor=True) + # Expected: field < 10 OR field == 10 + expected = pipeline_expressions.Or(field_expr.less_than(10), field_expr.equal(10)) + assert condition == expected + + # Case 2: StartAfter (exclusive) -> < 10 + cursor = ([10], False) + condition = _where_conditions_from_cursor(cursor, [ordering], is_start_cursor=True) + # Expected: field < 10 + expected = field_expr.less_than(10) + assert condition == expected + + # Case 3: EndAt (inclusive) -> >= 10 + cursor = ([10], False) + condition = _where_conditions_from_cursor(cursor, [ordering], is_start_cursor=False) + # Expected: field > 10 OR field == 10 + expected = pipeline_expressions.Or( + field_expr.greater_than(10), field_expr.equal(10) + ) + assert condition == expected + + # Case 4: EndBefore (exclusive) -> > 10 + cursor = ([10], True) + condition = _where_conditions_from_cursor(cursor, [ordering], is_start_cursor=False) + # Expected: field > 10 + expected = field_expr.greater_than(10) + assert condition == expected + + +def test__query_pipeline_end_at(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query_end = client.collection("my_col").order_by("field_a").end_at({"field_a": 10}) + pipeline = query_end._build_pipeline(client.pipeline()) + + # Stages: + # 0: Collection + # 1: Where (exists field_a) + # 2: Where (cursor condition) + # 3: Sort (field_a) + assert len(pipeline.stages) == 4 + + where_stage = pipeline.stages[2] + assert isinstance(where_stage, stages.Where) + # Expected: (field_a < 10) OR (field_a == 10) + assert isinstance(where_stage.condition, expr.Or) + params = where_stage.condition.params + assert len(params) == 2 + assert params[0].name == "less_than" + assert params[1].name == "equal" + + +def test__where_conditions_from_cursor_multi_field(): + from google.cloud.firestore_v1.base_query import _where_conditions_from_cursor + from google.cloud.firestore_v1 import pipeline_expressions as expr + + # Order by: A ASC, B DESC + field_a = expr.Field.of("A") + field_b = expr.Field.of("B") + ordering_a = expr.Ordering(field_a, "ascending") + ordering_b = expr.Ordering(field_b, "descending") + orderings = [ordering_a, ordering_b] + + # Cursor: A=1, B=2. StartAt (inclusive) + # Logic: (A > 1) OR (A == 1 AND (B < 2 OR B == 2)) + # Note: B is DESC, so start_at means <= 2 + cursor = ([1, 2], True) + + condition = _where_conditions_from_cursor(cursor, orderings, is_start_cursor=True) + + # Verify structure: Or(A > 1, And(A == 1, Or(B < 2, B == 2))) + assert isinstance(condition, expr.Or) + # First term: A > 1 + term1 = condition.params[0] + assert term1.name == "greater_than" + assert term1.params[0] == field_a + assert term1.params[1] == expr.Constant(1) + + # Second term: And(...) + term2 = condition.params[1] + assert isinstance(term2, expr.And) + + # Inside And: A == 1 + sub_term1 = term2.params[0] + assert sub_term1.name == "equal" + assert sub_term1.params[0] == field_a + assert sub_term1.params[1] == expr.Constant(1) + + # Inside And: Or(B < 2, B == 2) <-- DESCENDING logic check + sub_term2 = term2.params[1] + assert isinstance(sub_term2, expr.Or) + + # B < 2 (because DESC start_at) + sub_sub_term1 = sub_term2.params[0] + assert sub_sub_term1.name == "less_than" + assert sub_sub_term1.params[0] == field_b + assert sub_sub_term1.params[1] == expr.Constant(2) + + # B == 2 + sub_sub_term2 = sub_term2.params[1] + assert sub_sub_term2.name == "equal" + assert sub_sub_term2.params[0] == field_b + assert sub_sub_term2.params[1] == expr.Constant(2) + + +def test__reverse_orderings_descending(): + from google.cloud.firestore_v1.base_query import _reverse_orderings + from google.cloud.firestore_v1 import pipeline_expressions as expr + + # Input: A ASC, B DESC + field_a = expr.Field.of("A") + field_b = expr.Field.of("B") + ord_a = expr.Ordering(field_a, "ascending") + ord_b = expr.Ordering(field_b, "descending") + + reversed_ords = _reverse_orderings([ord_a, ord_b]) + + assert len(reversed_ords) == 2 + # Expect: A DESC, B ASC + assert reversed_ords[0].expr == field_a + assert reversed_ords[0].order_dir == expr.Ordering.Direction.DESCENDING + + assert reversed_ords[1].expr == field_b + assert reversed_ords[1].order_dir == expr.Ordering.Direction.ASCENDING diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index e2c6dcd0f..258f0eedf 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -463,58 +463,72 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( - "op_enum, value, expected_expr_func", + "op_enum, value, expected_expr_func, expects_existance", [ ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, Expression.less_than, + True, ), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, Expression.less_than_or_equal, + True, ), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, Expression.greater_than, + True, ), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, Expression.greater_than_or_equal, + True, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + 10, + Expression.equal, + True, ), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expression.equal), ( query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, Expression.not_equal, + False, ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, Expression.array_contains, + True, ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, [10, 20], Expression.array_contains_any, + True, ), ( query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], Expression.equal_any, + True, ), ( query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], Expression.not_equal_any, + False, ), ], ) def test__from_query_filter_pb_field_filter( - self, mock_client, op_enum, value, expected_expr_func + self, mock_client, op_enum, value, expected_expr_func, expects_existance ): """ test supported field filters @@ -536,10 +550,11 @@ def test__from_query_filter_pb_field_filter( [Constant(e) for e in value] if isinstance(value, list) else Constant(value) ) expected_condition = expected_expr_func(field_expr, value) - # should include existance checks - expected = expr.And(field_expr.exists(), expected_condition) + if expects_existance: + # some expressions include extra existance checks + expected_condition = expr.And(field_expr.exists(), expected_condition) - assert repr(result) == repr(expected) + assert repr(result) == repr(expected_condition) def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): """