Skip to content

Commit c7605b6

Browse files
committed
Generic Stage and Refactor
1 parent 63cd54e commit c7605b6

File tree

9 files changed

+274
-121
lines changed

9 files changed

+274
-121
lines changed

firebase-firestore/src/androidTest/java/com/google/firebase/firestore/PipelineTest.java

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,13 @@
3737
import static com.google.firebase.firestore.pipeline.Function.subtract;
3838
import static com.google.firebase.firestore.pipeline.Ordering.ascending;
3939
import static com.google.firebase.firestore.testutil.IntegrationTestUtil.waitFor;
40-
import static java.util.Map.entry;
4140

4241
import androidx.test.ext.junit.runners.AndroidJUnit4;
4342
import com.google.android.gms.tasks.Task;
4443
import com.google.common.collect.ImmutableList;
4544
import com.google.common.collect.ImmutableMap;
4645
import com.google.common.truth.Correspondence;
47-
import com.google.firebase.firestore.pipeline.Accumulator;
46+
import com.google.firebase.firestore.pipeline.AggregateExpr;
4847
import com.google.firebase.firestore.pipeline.AggregateStage;
4948
import com.google.firebase.firestore.pipeline.Constant;
5049
import com.google.firebase.firestore.pipeline.Field;
@@ -227,7 +226,7 @@ public void aggregateResultsCountAll() {
227226
firestore
228227
.pipeline()
229228
.collection(randomCol)
230-
.aggregate(Accumulator.countAll().as("count"))
229+
.aggregate(AggregateExpr.countAll().as("count"))
231230
.execute();
232231
assertThat(waitFor(execute).getResults())
233232
.comparingElementsUsing(DATA_CORRESPONDENCE)
@@ -243,8 +242,8 @@ public void aggregateResultsMany() {
243242
.collection(randomCol)
244243
.where(Function.eq("genre", "Science Fiction"))
245244
.aggregate(
246-
Accumulator.countAll().as("count"),
247-
Accumulator.avg("rating").as("avgRating"),
245+
AggregateExpr.countAll().as("count"),
246+
AggregateExpr.avg("rating").as("avgRating"),
248247
Field.of("rating").max().as("maxRating"))
249248
.execute();
250249
assertThat(waitFor(execute).getResults())
@@ -261,7 +260,7 @@ public void groupAndAccumulateResults() {
261260
.collection(randomCol)
262261
.where(lt(Field.of("published"), 1984))
263262
.aggregate(
264-
AggregateStage.withAccumulators(Accumulator.avg("rating").as("avgRating"))
263+
AggregateStage.withAccumulators(AggregateExpr.avg("rating").as("avgRating"))
265264
.withGroups("genre"))
266265
.where(gt("avgRating", 4.3))
267266
.sort(Field.of("avgRating").descending())
@@ -274,6 +273,28 @@ public void groupAndAccumulateResults() {
274273
mapOfEntries(entry("avgRating", 4.4), entry("genre", "Science Fiction")));
275274
}
276275

276+
@Test
277+
public void groupAndAccumulateResultsGeneric() {
278+
Task<PipelineSnapshot> execute =
279+
firestore
280+
.pipeline()
281+
.collection(randomCol)
282+
.genericStage("where", lt(Field.of("published"), 1984))
283+
.genericStage(
284+
"aggregate",
285+
ImmutableMap.of("avgRating", AggregateExpr.avg("rating")),
286+
ImmutableMap.of("genre", Field.of("genre")))
287+
.genericStage("where", gt("avgRating", 4.3))
288+
.genericStage("sort", Field.of("avgRating").descending())
289+
.execute();
290+
assertThat(waitFor(execute).getResults())
291+
.comparingElementsUsing(DATA_CORRESPONDENCE)
292+
.containsExactly(
293+
mapOfEntries(entry("avgRating", 4.7), entry("genre", "Fantasy")),
294+
mapOfEntries(entry("avgRating", 4.5), entry("genre", "Romance")),
295+
mapOfEntries(entry("avgRating", 4.4), entry("genre", "Science Fiction")));
296+
}
297+
277298
@Test
278299
@Ignore("Not supported yet")
279300
public void minAndMaxAccumulations() {
@@ -282,7 +303,7 @@ public void minAndMaxAccumulations() {
282303
.pipeline()
283304
.collection(randomCol)
284305
.aggregate(
285-
Accumulator.countAll().as("count"),
306+
AggregateExpr.countAll().as("count"),
286307
Field.of("rating").max().as("maxRating"),
287308
Field.of("published").min().as("minPublished"))
288309
.execute();
@@ -781,6 +802,30 @@ public void testMapGetWithFieldNameIncludingNotation() {
781802
entry("nested", null)));
782803
}
783804

805+
@Test
806+
public void testListEquals() {
807+
Task<PipelineSnapshot> execute =
808+
randomCol
809+
.pipeline()
810+
.where(eq("tags", ImmutableList.of("philosophy", "crime", "redemption")))
811+
.execute();
812+
assertThat(waitFor(execute).getResults())
813+
.comparingElementsUsing(ID_CORRESPONDENCE)
814+
.containsExactly("book6");
815+
}
816+
817+
@Test
818+
public void testMapEquals() {
819+
Task<PipelineSnapshot> execute =
820+
randomCol
821+
.pipeline()
822+
.where(eq("awards", ImmutableMap.of("nobel", true, "nebula", false)))
823+
.execute();
824+
assertThat(waitFor(execute).getResults())
825+
.comparingElementsUsing(ID_CORRESPONDENCE)
826+
.containsExactly("book3");
827+
}
828+
784829
static <T> Map.Entry<String, T> entry(String key, T value) {
785830
return new Map.Entry<String, T>() {
786831
private String k = key;

firebase-firestore/src/main/java/com/google/firebase/firestore/FirebaseFirestore.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import androidx.annotation.Keep;
2424
import androidx.annotation.NonNull;
2525
import androidx.annotation.Nullable;
26+
import androidx.annotation.RestrictTo;
2627
import androidx.annotation.VisibleForTesting;
2728
import com.google.android.gms.tasks.Task;
2829
import com.google.android.gms.tasks.TaskCompletionSource;
@@ -855,7 +856,9 @@ DatabaseId getDatabaseId() {
855856
return databaseId;
856857
}
857858

858-
UserDataReader getUserDataReader() {
859+
@NonNull
860+
@RestrictTo(RestrictTo.Scope.LIBRARY)
861+
public UserDataReader getUserDataReader() {
859862
return userDataReader;
860863
}
861864

firebase-firestore/src/main/java/com/google/firebase/firestore/Pipeline.kt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@ import com.google.common.collect.FluentIterable
2020
import com.google.common.collect.ImmutableList
2121
import com.google.firebase.firestore.model.DocumentKey
2222
import com.google.firebase.firestore.model.SnapshotVersion
23-
import com.google.firebase.firestore.pipeline.AccumulatorWithAlias
2423
import com.google.firebase.firestore.pipeline.AddFieldsStage
2524
import com.google.firebase.firestore.pipeline.AggregateStage
25+
import com.google.firebase.firestore.pipeline.AggregateWithAlias
2626
import com.google.firebase.firestore.pipeline.BooleanExpr
2727
import com.google.firebase.firestore.pipeline.CollectionGroupSource
2828
import com.google.firebase.firestore.pipeline.CollectionSource
2929
import com.google.firebase.firestore.pipeline.DatabaseSource
3030
import com.google.firebase.firestore.pipeline.DistinctStage
3131
import com.google.firebase.firestore.pipeline.DocumentsSource
3232
import com.google.firebase.firestore.pipeline.Field
33+
import com.google.firebase.firestore.pipeline.GenericArg
34+
import com.google.firebase.firestore.pipeline.GenericStage
3335
import com.google.firebase.firestore.pipeline.LimitStage
3436
import com.google.firebase.firestore.pipeline.OffsetStage
3537
import com.google.firebase.firestore.pipeline.Ordering
@@ -84,9 +86,12 @@ internal constructor(
8486

8587
internal fun toPipelineProto(): com.google.firestore.v1.Pipeline =
8688
com.google.firestore.v1.Pipeline.newBuilder()
87-
.addAllStages(stages.map(Stage::toProtoStage))
89+
.addAllStages(stages.map { it.toProtoStage(firestore.userDataReader) })
8890
.build()
8991

92+
fun genericStage(name: String, vararg params: Any) =
93+
append(GenericStage(name, params.map(GenericArg::from)))
94+
9095
fun addFields(vararg fields: Selectable): Pipeline = append(AddFieldsStage(fields))
9196

9297
fun removeFields(vararg fields: Field): Pipeline = append(RemoveFieldsStage(fields))
@@ -118,7 +123,7 @@ internal constructor(
118123
fun distinct(vararg groups: Any): Pipeline =
119124
append(DistinctStage(groups.map(Selectable::toSelectable).toTypedArray()))
120125

121-
fun aggregate(vararg accumulators: AccumulatorWithAlias): Pipeline =
126+
fun aggregate(vararg accumulators: AggregateWithAlias): Pipeline =
122127
append(AggregateStage.withAccumulators(*accumulators))
123128

124129
fun aggregate(aggregateStage: AggregateStage): Pipeline = append(aggregateStage)

firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import androidx.annotation.Nullable;
2121
import androidx.annotation.RestrictTo;
22+
import com.google.common.base.Function;
2223
import com.google.firebase.firestore.FieldValue.ArrayRemoveFieldValue;
2324
import com.google.firebase.firestore.FieldValue.ArrayUnionFieldValue;
2425
import com.google.firebase.firestore.FieldValue.DeleteFieldValue;
@@ -36,6 +37,7 @@
3637
import com.google.firebase.firestore.model.mutation.FieldMask;
3738
import com.google.firebase.firestore.model.mutation.NumericIncrementTransformOperation;
3839
import com.google.firebase.firestore.model.mutation.ServerTimestampOperation;
40+
import com.google.firebase.firestore.pipeline.Expr;
3941
import com.google.firebase.firestore.util.Assert;
4042
import com.google.firebase.firestore.util.CustomClassMapper;
4143
import com.google.firebase.firestore.util.Util;
@@ -389,6 +391,12 @@ public Value parseScalarValue(Object input, ParseContext context) {
389391
return Values.NULL_VALUE;
390392
} else if (input.getClass().isArray()) {
391393
throw context.createError("Arrays are not supported; use a List instead");
394+
} else if (input instanceof DocumentReference) {
395+
DocumentReference ref = (DocumentReference) input;
396+
validateDocumentReference(ref, context::createError);
397+
return Values.encodeValue(ref);
398+
} else if (input instanceof Expr) {
399+
throw context.createError("Pipeline expressions are not supported user objects");
392400
} else {
393401
try {
394402
return Values.encodeAnyValue(input);
@@ -398,6 +406,20 @@ public Value parseScalarValue(Object input, ParseContext context) {
398406
}
399407
}
400408

409+
public void validateDocumentReference(
410+
DocumentReference ref, Function<String, RuntimeException> createError) {
411+
DatabaseId otherDb = ref.getFirestore().getDatabaseId();
412+
if (!otherDb.equals(databaseId)) {
413+
throw createError.apply(
414+
String.format(
415+
"Document reference is for database %s/%s but should be for database %s/%s",
416+
otherDb.getProjectId(),
417+
otherDb.getDatabaseId(),
418+
databaseId.getProjectId(),
419+
databaseId.getDatabaseId()));
420+
}
421+
}
422+
401423
private List<Value> parseArrayTransformElements(List<Object> elements) {
402424
ParseAccumulator accumulator = new ParseAccumulator(UserData.Source.Argument);
403425

firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -609,9 +609,7 @@ internal object Values {
609609

610610
return Value.newBuilder()
611611
.setTimestampValue(
612-
com.google.protobuf.Timestamp.newBuilder()
613-
.setSeconds(timestamp.seconds)
614-
.setNanos(truncatedNanoseconds)
612+
Timestamp.newBuilder().setSeconds(timestamp.seconds).setNanos(truncatedNanoseconds)
615613
)
616614
.build()
617615
}
@@ -665,6 +663,11 @@ internal object Values {
665663
return Value.newBuilder().setMapValue(MapValue.newBuilder().putAllFields(map)).build()
666664
}
667665

666+
@JvmStatic
667+
fun encodeValue(values: Iterable<Value>): Value {
668+
return Value.newBuilder().setArrayValue(ArrayValue.newBuilder().addAllValues(values)).build()
669+
}
670+
668671
@JvmStatic
669672
fun encodeAnyValue(value: Any?): Value {
670673
return when (value) {
@@ -676,7 +679,6 @@ internal object Values {
676679
is Boolean -> encodeValue(value)
677680
is GeoPoint -> encodeValue(value)
678681
is Blob -> encodeValue(value)
679-
is DocumentReference -> encodeValue(value)
680682
is VectorValue -> encodeValue(value)
681683
else -> throw IllegalArgumentException("Unexpected type: $value")
682684
}

firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/Constant.kt

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,76 +18,70 @@ import com.google.firebase.Timestamp
1818
import com.google.firebase.firestore.Blob
1919
import com.google.firebase.firestore.DocumentReference
2020
import com.google.firebase.firestore.GeoPoint
21+
import com.google.firebase.firestore.UserDataReader
2122
import com.google.firebase.firestore.VectorValue
2223
import com.google.firebase.firestore.model.Values
2324
import com.google.firebase.firestore.model.Values.encodeValue
2425
import com.google.firestore.v1.Value
2526
import java.util.Date
2627

27-
class Constant internal constructor(val value: Value) : Expr() {
28+
abstract class Constant internal constructor() : Expr() {
29+
30+
private class ValueConstant(val value: Value) : Constant() {
31+
override fun toProto(userDataReader: UserDataReader): Value = value
32+
}
2833

2934
companion object {
30-
internal val NULL = Constant(Values.NULL_VALUE)
31-
32-
fun of(value: Any): Constant {
33-
return when (value) {
34-
is String -> of(value)
35-
is Number -> of(value)
36-
is Date -> of(value)
37-
is Timestamp -> of(value)
38-
is Boolean -> of(value)
39-
is GeoPoint -> of(value)
40-
is Blob -> of(value)
41-
is DocumentReference -> of(value)
42-
is Value -> of(value)
43-
is VectorValue -> of(value)
44-
else -> throw IllegalArgumentException("Unknown type: $value")
45-
}
46-
}
35+
internal val NULL: Constant = ValueConstant(Values.NULL_VALUE)
4736

4837
@JvmStatic
4938
fun of(value: String): Constant {
50-
return Constant(encodeValue(value))
39+
return ValueConstant(encodeValue(value))
5140
}
5241

5342
@JvmStatic
5443
fun of(value: Number): Constant {
55-
return Constant(encodeValue(value))
44+
return ValueConstant(encodeValue(value))
5645
}
5746

5847
@JvmStatic
5948
fun of(value: Date): Constant {
60-
return Constant(encodeValue(value))
49+
return ValueConstant(encodeValue(value))
6150
}
6251

6352
@JvmStatic
6453
fun of(value: Timestamp): Constant {
65-
return Constant(encodeValue(value))
54+
return ValueConstant(encodeValue(value))
6655
}
6756

6857
@JvmStatic
6958
fun of(value: Boolean): Constant {
70-
return Constant(encodeValue(value))
59+
return ValueConstant(encodeValue(value))
7160
}
7261

7362
@JvmStatic
7463
fun of(value: GeoPoint): Constant {
75-
return Constant(encodeValue(value))
64+
return ValueConstant(encodeValue(value))
7665
}
7766

7867
@JvmStatic
7968
fun of(value: Blob): Constant {
80-
return Constant(encodeValue(value))
69+
return ValueConstant(encodeValue(value))
8170
}
8271

8372
@JvmStatic
84-
fun of(value: DocumentReference): Constant {
85-
return Constant(encodeValue(value))
73+
fun of(ref: DocumentReference): Constant {
74+
return object : Constant() {
75+
override fun toProto(userDataReader: UserDataReader): Value {
76+
userDataReader.validateDocumentReference(ref, ::IllegalArgumentException)
77+
return encodeValue(ref)
78+
}
79+
}
8680
}
8781

8882
@JvmStatic
8983
fun of(value: VectorValue): Constant {
90-
return Constant(encodeValue(value))
84+
return ValueConstant(encodeValue(value))
9185
}
9286

9387
@JvmStatic
@@ -97,16 +91,12 @@ class Constant internal constructor(val value: Value) : Expr() {
9791

9892
@JvmStatic
9993
fun vector(value: DoubleArray): Constant {
100-
return Constant(Values.encodeVectorValue(value))
94+
return ValueConstant(Values.encodeVectorValue(value))
10195
}
10296

10397
@JvmStatic
10498
fun vector(value: VectorValue): Constant {
105-
return Constant(encodeValue(value))
99+
return ValueConstant(encodeValue(value))
106100
}
107101
}
108-
109-
override fun toProto(): Value {
110-
return value
111-
}
112102
}

0 commit comments

Comments
 (0)