@@ -3,6 +3,7 @@ package com.google.firebase.firestore.pipeline
3
3
import com.google.common.math.LongMath
4
4
import com.google.common.math.LongMath.checkedAdd
5
5
import com.google.common.math.LongMath.checkedMultiply
6
+ import com.google.common.math.LongMath.checkedSubtract
6
7
import com.google.firebase.firestore.UserDataReader
7
8
import com.google.firebase.firestore.model.MutableDocument
8
9
import com.google.firebase.firestore.model.Values
@@ -273,9 +274,71 @@ internal val evaluateStrJoin = notImplemented // TODO: Does not exist in express
273
274
274
275
// === Date / Timestamp Functions ===
275
276
276
- internal val evaluateTimestampAdd = notImplemented
277
+ private const val L_NANOS_PER_SECOND : Long = 1000_000_000
278
+ private const val I_NANOS_PER_SECOND : Int = 1000_000_000
277
279
278
- internal val evaluateTimestampSub = notImplemented
280
+ private const val L_MICROS_PER_SECOND : Long = 1000_000
281
+ private const val I_MICROS_PER_SECOND : Int = 1000_000
282
+
283
+ private const val L_MILLIS_PER_SECOND : Long = 1000
284
+ private const val I_MILLIS_PER_SECOND : Int = 1000
285
+
286
+ internal fun plus (t : Timestamp , seconds : Long , nanos : Long ): Timestamp =
287
+ if (nanos == 0L ) {
288
+ plus(t, seconds)
289
+ } else {
290
+ val nanoSum = t.nanos + nanos // Overflow not possible since nanos is 0 to 1 000 000.
291
+ val secondsSum: Long = checkedAdd(checkedAdd(t.seconds, seconds), nanoSum / L_NANOS_PER_SECOND )
292
+ Values .timestamp(secondsSum, (nanoSum % I_NANOS_PER_SECOND ).toInt())
293
+ }
294
+
295
+ private fun plus (t : Timestamp , seconds : Long ): Timestamp =
296
+ if (seconds == 0L ) t
297
+ else Values .timestamp(checkedAdd(t.seconds, seconds), t.nanos)
298
+
299
+ internal fun minus (t : Timestamp , seconds : Long , nanos : Long ): Timestamp =
300
+ if (nanos == 0L ) {
301
+ minus(t, seconds)
302
+ } else {
303
+ val nanoSum = t.nanos - nanos // Overflow not possible since nanos is 0 to 1 000 000.
304
+ val secondsSum: Long = checkedSubtract(t.seconds, checkedSubtract(seconds, nanoSum / L_NANOS_PER_SECOND ))
305
+ Values .timestamp(secondsSum, (nanoSum % I_NANOS_PER_SECOND ).toInt())
306
+ }
307
+
308
+ private fun minus (t : Timestamp , seconds : Long ): Timestamp =
309
+ if (seconds == 0L ) t
310
+ else Values .timestamp(checkedSubtract(t.seconds, seconds), t.nanos)
311
+
312
+
313
+ internal val evaluateTimestampAdd =
314
+ ternaryTimestampFunction { t: Timestamp , u: String , n: Long ->
315
+ EvaluateResult .timestamp(
316
+ when (u) {
317
+ " microsecond" -> plus(t, n / L_MICROS_PER_SECOND , (n % L_MICROS_PER_SECOND ) * 1000 )
318
+ " millisecond" -> plus(t, n / L_MILLIS_PER_SECOND , (n % L_MILLIS_PER_SECOND ) * 1000_000 )
319
+ " second" -> plus(t, n)
320
+ " minute" -> plus(t, checkedMultiply(n, 60 ))
321
+ " hour" -> plus(t, checkedMultiply(n, 3600 ))
322
+ " day" -> plus(t, checkedMultiply(n, 86400 ))
323
+ else -> return @ternaryTimestampFunction EvaluateResultError
324
+ }
325
+ )
326
+ }
327
+
328
+ internal val evaluateTimestampSub =
329
+ ternaryTimestampFunction { t: Timestamp , u: String , n: Long ->
330
+ EvaluateResult .timestamp(
331
+ when (u) {
332
+ " microsecond" -> minus(t, n / L_MICROS_PER_SECOND , (n % L_MICROS_PER_SECOND ) * 1000 )
333
+ " millisecond" -> minus(t, n / L_MILLIS_PER_SECOND , (n % L_MILLIS_PER_SECOND ) * 1000_000 )
334
+ " second" -> minus(t, n)
335
+ " minute" -> minus(t, checkedMultiply(n, 60 ))
336
+ " hour" -> minus(t, checkedMultiply(n, 3600 ))
337
+ " day" -> minus(t, checkedMultiply(n, 86400 ))
338
+ else -> return @ternaryTimestampFunction EvaluateResultError
339
+ }
340
+ )
341
+ }
279
342
280
343
internal val evaluateTimestampTrunc = notImplemented // TODO: Does not exist in expressions.kt yet.
281
344
@@ -284,39 +347,46 @@ internal val evaluateTimestampToUnixMicros = unaryFunction { t: Timestamp ->
284
347
if (t.seconds < Long .MIN_VALUE / 1_000_000 ) {
285
348
// To avoid overflow when very close to Long.MIN_VALUE, add 1 second, multiply, then subtract
286
349
// again.
287
- val micros = checkedMultiply(t.seconds + 1 , 1_000_000 )
288
- val adjustment = t.nanos.toLong() / 1_000 - 1_000_000
350
+ val micros = checkedMultiply(t.seconds + 1 , L_MICROS_PER_SECOND )
351
+ val adjustment = t.nanos.toLong() / L_MILLIS_PER_SECOND - L_MICROS_PER_SECOND
289
352
checkedAdd(micros, adjustment)
290
353
} else {
291
- val micros = checkedMultiply(t.seconds, 1_000_000 )
292
- checkedAdd(micros, t.nanos.toLong() / 1_000 )
354
+ val micros = checkedMultiply(t.seconds, L_MICROS_PER_SECOND )
355
+ checkedAdd(micros, t.nanos.toLong() / L_MILLIS_PER_SECOND )
293
356
}
294
357
)
295
358
}
296
359
297
360
internal val evaluateTimestampToUnixMillis = unaryFunction { t: Timestamp ->
298
361
EvaluateResult .long(
299
362
if (t.seconds < 0 && t.nanos > 0 ) {
300
- val millis = checkedMultiply(t.seconds + 1 , 1000 )
301
- val adjustment = t.nanos.toLong() / 1000_000 - 1000
363
+ val millis = checkedMultiply(t.seconds + 1 , L_MILLIS_PER_SECOND )
364
+ val adjustment = t.nanos.toLong() / L_MICROS_PER_SECOND - L_MILLIS_PER_SECOND
302
365
checkedAdd(millis, adjustment)
303
366
} else {
304
- val millis = checkedMultiply(t.seconds, 1000 )
305
- checkedAdd(millis, t.nanos.toLong() / 1000_000 )
367
+ val millis = checkedMultiply(t.seconds, L_MILLIS_PER_SECOND )
368
+ checkedAdd(millis, t.nanos.toLong() / L_MICROS_PER_SECOND )
306
369
}
307
370
)
308
371
}
309
372
310
373
internal val evaluateTimestampToUnixSeconds = unaryFunction { t: Timestamp ->
311
- if (t.nanos !in 0 until 1_000_000_000 ) EvaluateResultError else EvaluateResult .long(t.seconds)
374
+ if (t.nanos !in 0 until L_NANOS_PER_SECOND ) EvaluateResultError
375
+ else EvaluateResult .long(t.seconds)
312
376
}
313
377
314
378
internal val evaluateUnixMicrosToTimestamp = unaryFunction { micros: Long ->
315
- EvaluateResult .timestamp(Math .floorDiv(micros, 1000_000 ), Math .floorMod(micros, 1000_000 ))
379
+ EvaluateResult .timestamp(
380
+ Math .floorDiv(micros, L_MICROS_PER_SECOND ),
381
+ Math .floorMod(micros, I_MICROS_PER_SECOND )
382
+ )
316
383
}
317
384
318
385
internal val evaluateUnixMillisToTimestamp = unaryFunction { millis: Long ->
319
- EvaluateResult .timestamp(Math .floorDiv(millis, 1000 ), Math .floorMod(millis, 1000 ))
386
+ EvaluateResult .timestamp(
387
+ Math .floorDiv(millis, L_MILLIS_PER_SECOND ),
388
+ Math .floorMod(millis, I_MILLIS_PER_SECOND )
389
+ )
320
390
}
321
391
322
392
internal val evaluateUnixSecondsToTimestamp = unaryFunction { seconds: Long ->
@@ -457,6 +527,43 @@ private inline fun binaryFunction(crossinline function: (String, String) -> Eval
457
527
function
458
528
)
459
529
530
+ private inline fun ternaryTimestampFunction (
531
+ crossinline function : (Timestamp , String , Long ) -> EvaluateResult
532
+ ): EvaluateFunction = ternaryNullableValueFunction { timestamp: Value , unit: Value , number: Value ->
533
+ val t: Timestamp =
534
+ when (timestamp.valueTypeCase) {
535
+ Value .ValueTypeCase .NULL_VALUE -> return @ternaryNullableValueFunction EvaluateResult .NULL
536
+ Value .ValueTypeCase .TIMESTAMP_VALUE -> timestamp.timestampValue
537
+ else -> return @ternaryNullableValueFunction EvaluateResultError
538
+ }
539
+ val u: String =
540
+ if (unit.hasStringValue()) unit.stringValue
541
+ else return @ternaryNullableValueFunction EvaluateResultError
542
+ val n: Long =
543
+ when (number.valueTypeCase) {
544
+ Value .ValueTypeCase .NULL_VALUE -> return @ternaryNullableValueFunction EvaluateResult .NULL
545
+ Value .ValueTypeCase .INTEGER_VALUE -> number.integerValue
546
+ else -> return @ternaryNullableValueFunction EvaluateResultError
547
+ }
548
+ function(t, u, n)
549
+ }
550
+
551
+ private inline fun ternaryNullableValueFunction (
552
+ crossinline function : (Value , Value , Value ) -> EvaluateResult
553
+ ): EvaluateFunction = { params ->
554
+ if (params.size != 3 )
555
+ throw Assert .fail(" Function should have exactly 3 params, but %d were given." , params.size)
556
+ val p1 = params[0 ]
557
+ val p2 = params[1 ]
558
+ val p3 = params[2 ]
559
+ block@{ input: MutableDocument ->
560
+ val v1 = p1(input).value ? : return @block EvaluateResultError
561
+ val v2 = p2(input).value ? : return @block EvaluateResultError
562
+ val v3 = p3(input).value ? : return @block EvaluateResultError
563
+ catch { function(v1, v2, v3) }
564
+ }
565
+ }
566
+
460
567
private inline fun <T1 , T2 > binaryFunctionType (
461
568
valueTypeCase1 : Value .ValueTypeCase ,
462
569
crossinline valueExtractor1 : (Value ) -> T1 ,
0 commit comments