|
67 | 67 |
|
68 | 68 |
|
69 | 69 | _BAD_DIR_STRING: str |
70 | | -_BAD_OP_NAN_NULL: str |
| 70 | +_BAD_OP_NAN: str |
| 71 | +_BAD_OP_NULL: str |
71 | 72 | _BAD_OP_STRING: str |
72 | 73 | _COMPARISON_OPERATORS: Dict[str, Any] |
73 | 74 | _EQ_OP: str |
| 75 | +_NEQ_OP: str |
74 | 76 | _INVALID_CURSOR_TRANSFORM: str |
75 | 77 | _INVALID_WHERE_TRANSFORM: str |
76 | 78 | _MISMATCH_CURSOR_W_ORDER_BY: str |
|
80 | 82 |
|
81 | 83 |
|
82 | 84 | _EQ_OP = "==" |
| 85 | +_NEQ_OP = "!=" |
83 | 86 | _operator_enum = StructuredQuery.FieldFilter.Operator |
84 | 87 | _COMPARISON_OPERATORS = { |
85 | 88 | "<": _operator_enum.LESS_THAN, |
86 | 89 | "<=": _operator_enum.LESS_THAN_OR_EQUAL, |
87 | 90 | _EQ_OP: _operator_enum.EQUAL, |
88 | | - "!=": _operator_enum.NOT_EQUAL, |
| 91 | + _NEQ_OP: _operator_enum.NOT_EQUAL, |
89 | 92 | ">=": _operator_enum.GREATER_THAN_OR_EQUAL, |
90 | 93 | ">": _operator_enum.GREATER_THAN, |
91 | 94 | "array_contains": _operator_enum.ARRAY_CONTAINS, |
|
104 | 107 | _operator_enum.NOT_IN, |
105 | 108 | ) |
106 | 109 | _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." |
107 | | -_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' |
| 110 | +_BAD_OP_NAN_NULL = 'Only equality ("==") or not-equal ("!=") filters can be used with None or NaN values' |
108 | 111 | _INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." |
109 | 112 | _BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." |
110 | 113 | _INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." |
@@ -136,26 +139,49 @@ def _to_pb(self): |
136 | 139 | """Build the protobuf representation based on values in the filter""" |
137 | 140 |
|
138 | 141 |
|
| 142 | +def _validate_opation(op_string, value): |
| 143 | + """ |
| 144 | + Given an input operator string (e.g, '!='), and a value (e.g. None), |
| 145 | + ensure that the operator and value combination is valid, and return |
| 146 | + an approproate new operator value. A new operator will be used if |
| 147 | + the operaion is a comparison against Null or NaN |
| 148 | +
|
| 149 | + Args: |
| 150 | + op_string (Optional[str]): the requested operator |
| 151 | + value (Any): the value the operator is acting on |
| 152 | + Returns: |
| 153 | + str | StructuredQuery.UnaryFilter.Operator: operator to use in requests |
| 154 | + Raises: |
| 155 | + ValueError: if the operator and value combination is invalid |
| 156 | + """ |
| 157 | + if value is None: |
| 158 | + if op_string == _EQ_OP: |
| 159 | + return StructuredQuery.UnaryFilter.Operator.IS_NULL |
| 160 | + elif op_string == _NEQ_OP: |
| 161 | + return StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL |
| 162 | + else: |
| 163 | + raise ValueError(_BAD_OP_NAN_NULL) |
| 164 | + |
| 165 | + elif _isnan(value): |
| 166 | + if op_string == _EQ_OP: |
| 167 | + return StructuredQuery.UnaryFilter.Operator.IS_NAN |
| 168 | + elif op_string == _NEQ_OP: |
| 169 | + return StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN |
| 170 | + else: |
| 171 | + raise ValueError(_BAD_OP_NAN_NULL) |
| 172 | + elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): |
| 173 | + raise ValueError(_INVALID_WHERE_TRANSFORM) |
| 174 | + else: |
| 175 | + return op_string |
| 176 | + |
| 177 | + |
139 | 178 | class FieldFilter(BaseFilter): |
140 | 179 | """Class representation of a Field Filter.""" |
141 | 180 |
|
142 | 181 | def __init__(self, field_path, op_string, value=None): |
143 | 182 | self.field_path = field_path |
144 | 183 | self.value = value |
145 | | - |
146 | | - if value is None: |
147 | | - if op_string != _EQ_OP: |
148 | | - raise ValueError(_BAD_OP_NAN_NULL) |
149 | | - self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NULL |
150 | | - |
151 | | - elif _isnan(value): |
152 | | - if op_string != _EQ_OP: |
153 | | - raise ValueError(_BAD_OP_NAN_NULL) |
154 | | - self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NAN |
155 | | - elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): |
156 | | - raise ValueError(_INVALID_WHERE_TRANSFORM) |
157 | | - else: |
158 | | - self.op_string = op_string |
| 184 | + self.op_string = _validate_opation(op_string, value) |
159 | 185 |
|
160 | 186 | def _to_pb(self): |
161 | 187 | """Returns the protobuf representation, either a StructuredQuery.UnaryFilter or a StructuredQuery.FieldFilter""" |
@@ -478,22 +504,12 @@ def where( |
478 | 504 | UserWarning, |
479 | 505 | stacklevel=2, |
480 | 506 | ) |
481 | | - if value is None: |
482 | | - if op_string != _EQ_OP: |
483 | | - raise ValueError(_BAD_OP_NAN_NULL) |
484 | | - filter_pb = query.StructuredQuery.UnaryFilter( |
485 | | - field=query.StructuredQuery.FieldReference(field_path=field_path), |
486 | | - op=StructuredQuery.UnaryFilter.Operator.IS_NULL, |
487 | | - ) |
488 | | - elif _isnan(value): |
489 | | - if op_string != _EQ_OP: |
490 | | - raise ValueError(_BAD_OP_NAN_NULL) |
| 507 | + op = _validate_opation(op_string, value) |
| 508 | + if isinstance(op, StructuredQuery.UnaryFilter.Operator): |
491 | 509 | filter_pb = query.StructuredQuery.UnaryFilter( |
492 | 510 | field=query.StructuredQuery.FieldReference(field_path=field_path), |
493 | | - op=StructuredQuery.UnaryFilter.Operator.IS_NAN, |
| 511 | + op=op, |
494 | 512 | ) |
495 | | - elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): |
496 | | - raise ValueError(_INVALID_WHERE_TRANSFORM) |
497 | 513 | else: |
498 | 514 | filter_pb = query.StructuredQuery.FieldFilter( |
499 | 515 | field=query.StructuredQuery.FieldReference(field_path=field_path), |
|
0 commit comments