Skip to content

Commit 5eff0fa

Browse files
committed
Add comparison tests
1 parent 742546d commit 5eff0fa

File tree

6 files changed

+707
-647
lines changed

6 files changed

+707
-647
lines changed

firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.kt

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,16 @@ internal object Values {
107107
}
108108
}
109109

110-
fun strictEquals(left: Value, right: Value): Boolean {
110+
fun strictEquals(left: Value, right: Value): Boolean? {
111+
if (left.hasNullValue() || right.hasNullValue()) return null
111112
val leftType = typeOrder(left)
112113
val rightType = typeOrder(right)
113114
if (leftType != rightType) {
114115
return false
115116
}
116117

117118
return when (leftType) {
118-
TYPE_ORDER_NULL -> false
119+
TYPE_ORDER_NULL -> null
119120
TYPE_ORDER_NUMBER -> strictNumberEquals(left, right)
120121
TYPE_ORDER_ARRAY -> strictArrayEquals(left, right)
121122
TYPE_ORDER_VECTOR,
@@ -127,6 +128,15 @@ internal object Values {
127128
}
128129
}
129130

131+
fun strictCompare(left: Value, right: Value): Int? {
132+
val leftType = typeOrder(left)
133+
val rightType = typeOrder(right)
134+
if (leftType != rightType) {
135+
return null
136+
}
137+
return compareInternal(leftType, left, right)
138+
}
139+
130140
@JvmStatic
131141
fun equals(left: Value?, right: Value?): Boolean {
132142
if (left === right) {
@@ -156,43 +166,50 @@ internal object Values {
156166
}
157167

158168
private fun strictNumberEquals(left: Value, right: Value): Boolean {
159-
if (left.valueTypeCase != right.valueTypeCase) {
160-
return false
161-
}
162-
return when (left.valueTypeCase) {
163-
ValueTypeCase.INTEGER_VALUE -> left.integerValue == right.integerValue
164-
ValueTypeCase.DOUBLE_VALUE -> left.doubleValue == right.doubleValue
165-
else -> false
166-
}
169+
if (left.doubleValue.isNaN() || right.doubleValue.isNaN()) return false
170+
return numberEquals(left, right)
167171
}
168172

169-
private fun numberEquals(left: Value, right: Value): Boolean {
170-
if (left.valueTypeCase != right.valueTypeCase) {
171-
return false
172-
}
173-
return when (left.valueTypeCase) {
174-
ValueTypeCase.INTEGER_VALUE -> left.integerValue == right.integerValue
173+
private fun numberEquals(left: Value, right: Value): Boolean =
174+
when (left.valueTypeCase) {
175+
ValueTypeCase.INTEGER_VALUE ->
176+
when (right.valueTypeCase) {
177+
ValueTypeCase.INTEGER_VALUE -> left.integerValue == right.integerValue
178+
ValueTypeCase.DOUBLE_VALUE -> right.doubleValue.compareTo(left.integerValue) == 0
179+
else -> false
180+
}
175181
ValueTypeCase.DOUBLE_VALUE ->
176-
doubleToLongBits(left.doubleValue) == doubleToLongBits(right.doubleValue)
182+
when (right.valueTypeCase) {
183+
ValueTypeCase.INTEGER_VALUE ->
184+
compareDoubleWithLong(left.doubleValue, right.integerValue) == 0
185+
ValueTypeCase.DOUBLE_VALUE ->
186+
doubleToLongBits(left.doubleValue) == doubleToLongBits(right.doubleValue)
187+
else -> false
188+
}
177189
else -> false
178190
}
179-
}
180191

181-
private fun strictArrayEquals(left: Value, right: Value): Boolean {
192+
private fun compareDoubleWithLong(double: Double, long: Long): Int =
193+
if (double.isNaN()) -1 else double.compareTo(long)
194+
195+
private fun strictArrayEquals(left: Value, right: Value): Boolean? {
182196
val leftArray = left.arrayValue
183197
val rightArray = right.arrayValue
184198

185199
if (leftArray.valuesCount != rightArray.valuesCount) {
186200
return false
187201
}
188202

203+
var foundNull = false
189204
for (i in 0 until leftArray.valuesCount) {
190-
if (!strictEquals(leftArray.getValues(i), rightArray.getValues(i))) {
205+
val equals = strictEquals(leftArray.getValues(i), rightArray.getValues(i))
206+
if (equals === null) {
207+
foundNull = true
208+
} else if (!equals) {
191209
return false
192210
}
193211
}
194-
195-
return true
212+
return if (foundNull) null else true
196213
}
197214

198215
private fun arrayEquals(left: Value, right: Value): Boolean {
@@ -212,22 +229,26 @@ internal object Values {
212229
return true
213230
}
214231

215-
private fun strictObjectEquals(left: Value, right: Value): Boolean {
232+
private fun strictObjectEquals(left: Value, right: Value): Boolean? {
216233
val leftMap = left.mapValue
217234
val rightMap = right.mapValue
218235

219236
if (leftMap.fieldsCount != rightMap.fieldsCount) {
220237
return false
221238
}
222239

240+
var foundNull = false
223241
for ((key, value) in leftMap.fieldsMap) {
224242
val otherEntry = rightMap.fieldsMap[key] ?: return false
225-
if (!strictEquals(value, otherEntry)) {
243+
val equals = strictEquals(value, otherEntry)
244+
if (equals === null) {
245+
foundNull = true
246+
} else if (!equals) {
226247
return false
227248
}
228249
}
229250

230-
return true
251+
return if (foundNull) null else true
231252
}
232253

233254
private fun objectEquals(left: Value, right: Value): Boolean {
@@ -268,7 +289,11 @@ internal object Values {
268289
return Util.compareIntegers(leftType, rightType)
269290
}
270291

271-
return when (leftType) {
292+
return compareInternal(leftType, left, right)
293+
}
294+
295+
private fun compareInternal(leftType: Int, left: Value, right: Value): Int =
296+
when (leftType) {
272297
TYPE_ORDER_NULL,
273298
TYPE_ORDER_MAX_VALUE -> 0
274299
TYPE_ORDER_BOOLEAN -> Util.compareBooleans(left.booleanValue, right.booleanValue)
@@ -288,7 +313,6 @@ internal object Values {
288313
TYPE_ORDER_VECTOR -> compareVectors(left.mapValue, right.mapValue)
289314
else -> throw Assert.fail("Invalid value type: $leftType")
290315
}
291-
}
292316

293317
@JvmStatic
294318
fun lowerBoundCompare(
@@ -658,14 +682,11 @@ internal object Values {
658682
@JvmStatic
659683
fun encodeValue(value: Timestamp): Value = Value.newBuilder().setTimestampValue(value).build()
660684

661-
@JvmField
662-
val TRUE_VALUE: Value = Value.newBuilder().setBooleanValue(true).build()
685+
@JvmField val TRUE_VALUE: Value = Value.newBuilder().setBooleanValue(true).build()
663686

664-
@JvmField
665-
val FALSE_VALUE: Value = Value.newBuilder().setBooleanValue(false).build()
687+
@JvmField val FALSE_VALUE: Value = Value.newBuilder().setBooleanValue(false).build()
666688

667-
@JvmStatic
668-
fun encodeValue(value: Boolean): Value = if (value) TRUE_VALUE else FALSE_VALUE
689+
@JvmStatic fun encodeValue(value: Boolean): Value = if (value) TRUE_VALUE else FALSE_VALUE
669690

670691
@JvmStatic
671692
fun encodeValue(geoPoint: GeoPoint): Value =

firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/EvaluateResult.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ import com.google.protobuf.Timestamp
77

88
internal sealed class EvaluateResult(val value: Value?) {
99
abstract val isError: Boolean
10+
val isSuccess: Boolean
11+
get() = this is EvaluateResultValue
12+
val isUnset: Boolean
13+
get() = this is EvaluateResultUnset
1014

1115
companion object {
1216
val TRUE: EvaluateResultValue = EvaluateResultValue(Values.TRUE_VALUE)
1317
val FALSE: EvaluateResultValue = EvaluateResultValue(Values.FALSE_VALUE)
1418
val NULL: EvaluateResultValue = EvaluateResultValue(Values.NULL_VALUE)
1519
val DOUBLE_ZERO: EvaluateResultValue = double(0.0)
1620
val LONG_ZERO: EvaluateResultValue = long(0)
21+
fun boolean(boolean: Boolean?) = if (boolean === null) NULL else boolean(boolean)
1722
fun boolean(boolean: Boolean) = if (boolean) TRUE else FALSE
1823
fun double(double: Double) = EvaluateResultValue(encodeValue(double))
1924
fun long(long: Long) = EvaluateResultValue(encodeValue(long))

firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/evaluation.kt

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@file:JvmName("Evaluation")
2+
23
package com.google.firebase.firestore.pipeline
34

45
import com.google.common.math.LongMath
@@ -10,6 +11,8 @@ import com.google.firebase.firestore.model.MutableDocument
1011
import com.google.firebase.firestore.model.Values
1112
import com.google.firebase.firestore.model.Values.encodeValue
1213
import com.google.firebase.firestore.model.Values.isNanValue
14+
import com.google.firebase.firestore.model.Values.strictCompare
15+
import com.google.firebase.firestore.model.Values.strictEquals
1316
import com.google.firebase.firestore.util.Assert
1417
import com.google.firestore.v1.Value
1518
import com.google.protobuf.ByteString
@@ -76,17 +79,37 @@ internal val evaluateXor: EvaluateFunction = variadicFunction { values: BooleanA
7679

7780
// === Comparison Functions ===
7881

79-
internal val evaluateEq: EvaluateFunction = comparison(Values::strictEquals)
82+
internal val evaluateEq: EvaluateFunction = binaryFunction { p1: Value, p2: Value ->
83+
EvaluateResult.boolean(strictEquals(p1, p2))
84+
}
8085

81-
internal val evaluateNeq: EvaluateFunction = comparison { v1, v2 -> !Values.strictEquals(v1, v2) }
86+
internal val evaluateNeq: EvaluateFunction = binaryFunction { p1: Value, p2: Value ->
87+
EvaluateResult.boolean(strictEquals(p1, p2)?.not())
88+
}
8289

83-
internal val evaluateGt: EvaluateFunction = comparison { v1, v2 -> Values.compare(v1, v2) > 0 }
90+
internal val evaluateGt: EvaluateFunction = comparison { v1, v2 ->
91+
(strictCompare(v1, v2) ?: return@comparison false) > 0
92+
}
8493

85-
internal val evaluateGte: EvaluateFunction = comparison { v1, v2 -> Values.compare(v1, v2) >= 0 }
94+
internal val evaluateGte: EvaluateFunction = comparison { v1, v2 ->
95+
when (strictEquals(v1, v2)) {
96+
true -> true
97+
false -> (strictCompare(v1, v2) ?: return@comparison false) > 0
98+
null -> null
99+
}
100+
}
86101

87-
internal val evaluateLt: EvaluateFunction = comparison { v1, v2 -> Values.compare(v1, v2) < 0 }
102+
internal val evaluateLt: EvaluateFunction = comparison { v1, v2 ->
103+
(strictCompare(v1, v2) ?: return@comparison false) < 0
104+
}
88105

89-
internal val evaluateLte: EvaluateFunction = comparison { v1, v2 -> Values.compare(v1, v2) <= 0 }
106+
internal val evaluateLte: EvaluateFunction = comparison { v1, v2 ->
107+
when (strictEquals(v1, v2)) {
108+
true -> true
109+
false -> (strictCompare(v1, v2) ?: return@comparison false) < 0
110+
null -> null
111+
}
112+
}
90113

91114
internal val evaluateNot: EvaluateFunction = unaryFunction { b: Boolean ->
92115
EvaluateResult.boolean(b.not())
@@ -297,52 +320,48 @@ internal fun plus(t: Timestamp, seconds: Long, nanos: Long): Timestamp =
297320
}
298321

299322
private fun plus(t: Timestamp, seconds: Long): Timestamp =
300-
if (seconds == 0L) t
301-
else Values.timestamp(checkedAdd(t.seconds, seconds), t.nanos)
323+
if (seconds == 0L) t else Values.timestamp(checkedAdd(t.seconds, seconds), t.nanos)
302324

303325
internal fun minus(t: Timestamp, seconds: Long, nanos: Long): Timestamp =
304326
if (nanos == 0L) {
305327
minus(t, seconds)
306328
} else {
307329
val nanoSum = t.nanos - nanos // Overflow not possible since nanos is 0 to 1 000 000.
308-
val secondsSum: Long = checkedSubtract(t.seconds, checkedSubtract(seconds, nanoSum / L_NANOS_PER_SECOND))
330+
val secondsSum: Long =
331+
checkedSubtract(t.seconds, checkedSubtract(seconds, nanoSum / L_NANOS_PER_SECOND))
309332
Values.timestamp(secondsSum, (nanoSum % I_NANOS_PER_SECOND).toInt())
310333
}
311334

312335
private fun minus(t: Timestamp, seconds: Long): Timestamp =
313-
if (seconds == 0L) t
314-
else Values.timestamp(checkedSubtract(t.seconds, seconds), t.nanos)
315-
316-
317-
internal val evaluateTimestampAdd =
318-
ternaryTimestampFunction { t: Timestamp, u: String, n: Long ->
319-
EvaluateResult.timestamp(
320-
when (u) {
321-
"microsecond" -> plus(t, n / L_MICROS_PER_SECOND, (n % L_MICROS_PER_SECOND) * 1000)
322-
"millisecond" -> plus(t, n / L_MILLIS_PER_SECOND, (n % L_MILLIS_PER_SECOND) * 1000_000)
323-
"second" -> plus(t, n)
324-
"minute" -> plus(t, checkedMultiply(n, 60))
325-
"hour" -> plus(t, checkedMultiply(n, 3600))
326-
"day" -> plus(t, checkedMultiply(n, 86400))
327-
else -> return@ternaryTimestampFunction EvaluateResultError
328-
}
329-
)
330-
}
336+
if (seconds == 0L) t else Values.timestamp(checkedSubtract(t.seconds, seconds), t.nanos)
331337

332-
internal val evaluateTimestampSub =
333-
ternaryTimestampFunction { t: Timestamp, u: String, n: Long ->
334-
EvaluateResult.timestamp(
335-
when (u) {
336-
"microsecond" -> minus(t, n / L_MICROS_PER_SECOND, (n % L_MICROS_PER_SECOND) * 1000)
337-
"millisecond" -> minus(t, n / L_MILLIS_PER_SECOND, (n % L_MILLIS_PER_SECOND) * 1000_000)
338-
"second" -> minus(t, n)
339-
"minute" -> minus(t, checkedMultiply(n, 60))
340-
"hour" -> minus(t, checkedMultiply(n, 3600))
341-
"day" -> minus(t, checkedMultiply(n, 86400))
342-
else -> return@ternaryTimestampFunction EvaluateResultError
343-
}
344-
)
345-
}
338+
internal val evaluateTimestampAdd = ternaryTimestampFunction { t: Timestamp, u: String, n: Long ->
339+
EvaluateResult.timestamp(
340+
when (u) {
341+
"microsecond" -> plus(t, n / L_MICROS_PER_SECOND, (n % L_MICROS_PER_SECOND) * 1000)
342+
"millisecond" -> plus(t, n / L_MILLIS_PER_SECOND, (n % L_MILLIS_PER_SECOND) * 1000_000)
343+
"second" -> plus(t, n)
344+
"minute" -> plus(t, checkedMultiply(n, 60))
345+
"hour" -> plus(t, checkedMultiply(n, 3600))
346+
"day" -> plus(t, checkedMultiply(n, 86400))
347+
else -> return@ternaryTimestampFunction EvaluateResultError
348+
}
349+
)
350+
}
351+
352+
internal val evaluateTimestampSub = ternaryTimestampFunction { t: Timestamp, u: String, n: Long ->
353+
EvaluateResult.timestamp(
354+
when (u) {
355+
"microsecond" -> minus(t, n / L_MICROS_PER_SECOND, (n % L_MICROS_PER_SECOND) * 1000)
356+
"millisecond" -> minus(t, n / L_MILLIS_PER_SECOND, (n % L_MILLIS_PER_SECOND) * 1000_000)
357+
"second" -> minus(t, n)
358+
"minute" -> minus(t, checkedMultiply(n, 60))
359+
"hour" -> minus(t, checkedMultiply(n, 3600))
360+
"day" -> minus(t, checkedMultiply(n, 86400))
361+
else -> return@ternaryTimestampFunction EvaluateResultError
362+
}
363+
)
364+
}
346365

347366
internal val evaluateTimestampTrunc = notImplemented // TODO: Does not exist in expressions.kt yet.
348367

@@ -402,17 +421,18 @@ internal val evaluateUnixSecondsToTimestamp = unaryFunction { seconds: Long ->
402421
internal val evaluateMap: EvaluateFunction = { params ->
403422
if (params.size % 2 != 0)
404423
throw Assert.fail("Function should have even number of params, but %d were given.", params.size)
405-
else block@{ input: MutableDocument ->
406-
val map: MutableMap<String, Value> = HashMap(params.size / 2)
407-
for (i in params.indices step 2) {
408-
val k = params[i](input).value ?: return@block EvaluateResultError
409-
if (!k.hasStringValue()) return@block EvaluateResultError
410-
val v = params[i + 1](input).value ?: return@block EvaluateResultError
411-
// It is against the API contract to include a key more than once.
412-
if (map.put(k.stringValue, v) != null) return@block EvaluateResultError
424+
else
425+
block@{ input: MutableDocument ->
426+
val map: MutableMap<String, Value> = HashMap(params.size / 2)
427+
for (i in params.indices step 2) {
428+
val k = params[i](input).value ?: return@block EvaluateResultError
429+
if (!k.hasStringValue()) return@block EvaluateResultError
430+
val v = params[i + 1](input).value ?: return@block EvaluateResultError
431+
// It is against the API contract to include a key more than once.
432+
if (map.put(k.stringValue, v) != null) return@block EvaluateResultError
433+
}
434+
EvaluateResultValue(encodeValue(map))
413435
}
414-
EvaluateResultValue(encodeValue(map))
415-
}
416436
}
417437

418438
// === Helper Functions ===
@@ -688,10 +708,10 @@ private inline fun variadicFunction(
688708
}
689709
}
690710

691-
private inline fun comparison(crossinline predicate: (Value, Value) -> Boolean): EvaluateFunction =
711+
private inline fun comparison(crossinline f: (Value, Value) -> Boolean?): EvaluateFunction =
692712
binaryFunction { p1: Value, p2: Value ->
693713
if (isNanValue(p1) or isNanValue(p2)) EvaluateResult.FALSE
694-
else catch { EvaluateResult.boolean(predicate(p1, p2)) }
714+
else EvaluateResult.boolean(f(p1, p2))
695715
}
696716

697717
private inline fun arithmeticPrimitive(

0 commit comments

Comments
 (0)