Skip to content

Commit 17dc248

Browse files
committed
add transaction support
1 parent dda3c42 commit 17dc248

File tree

5 files changed

+104
-27
lines changed

5 files changed

+104
-27
lines changed

google-cloud-firestore/src/main/java/com/google/cloud/firestore/Pipeline.java

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@
5252
import com.google.firestore.v1.ExecutePipelineRequest;
5353
import com.google.firestore.v1.ExecutePipelineResponse;
5454
import com.google.firestore.v1.StructuredPipeline;
55+
import com.google.protobuf.ByteString;
5556
import io.opencensus.trace.AttributeValue;
5657
import io.opencensus.trace.Tracing;
5758
import java.util.ArrayList;
5859
import java.util.List;
5960
import java.util.logging.Level;
6061
import java.util.logging.Logger;
62+
import javax.annotation.Nullable;
6163

6264
/**
6365
* The Pipeline class provides a flexible and expressive framework for building complex data
@@ -597,29 +599,7 @@ public Pipeline genericStage(String name, List<Object> params) {
597599
*/
598600
@BetaApi
599601
public ApiFuture<List<PipelineResult>> execute() {
600-
SettableApiFuture<List<PipelineResult>> futureResult = SettableApiFuture.create();
601-
602-
execute( // Assuming you have this method
603-
new PipelineResultObserver() {
604-
final List<PipelineResult> results = new ArrayList<>();
605-
606-
@Override
607-
public void onCompleted() {
608-
futureResult.set(results);
609-
}
610-
611-
@Override
612-
public void onNext(PipelineResult result) {
613-
results.add(result);
614-
}
615-
616-
@Override
617-
public void onError(Throwable t) {
618-
futureResult.setException(t);
619-
}
620-
});
621-
622-
return futureResult;
602+
return execute(null, null);
623603
}
624604

625605
/**
@@ -669,14 +649,57 @@ public void onError(Throwable t) {
669649
*/
670650
@BetaApi
671651
public void execute(ApiStreamObserver<PipelineResult> observer) {
672-
ExecutePipelineRequest request =
652+
executeInternal(null, null, observer);
653+
}
654+
655+
ApiFuture<List<PipelineResult>> execute(
656+
@Nullable final ByteString transactionId, @Nullable com.google.protobuf.Timestamp readTime) {
657+
SettableApiFuture<List<PipelineResult>> futureResult = SettableApiFuture.create();
658+
659+
executeInternal(
660+
transactionId,
661+
readTime,
662+
new PipelineResultObserver() {
663+
final List<PipelineResult> results = new ArrayList<>();
664+
665+
@Override
666+
public void onCompleted() {
667+
futureResult.set(results);
668+
}
669+
670+
@Override
671+
public void onNext(PipelineResult result) {
672+
results.add(result);
673+
}
674+
675+
@Override
676+
public void onError(Throwable t) {
677+
futureResult.setException(t);
678+
}
679+
});
680+
681+
return futureResult;
682+
}
683+
684+
void executeInternal(
685+
@Nullable final ByteString transactionId,
686+
@Nullable com.google.protobuf.Timestamp readTime,
687+
ApiStreamObserver<PipelineResult> observer) {
688+
ExecutePipelineRequest.Builder request =
673689
ExecutePipelineRequest.newBuilder()
674690
.setDatabase(rpcContext.getDatabaseName())
675-
.setStructuredPipeline(StructuredPipeline.newBuilder().setPipeline(toProto()).build())
676-
.build();
691+
.setStructuredPipeline(StructuredPipeline.newBuilder().setPipeline(toProto()).build());
692+
693+
if (transactionId != null) {
694+
request.setTransaction(transactionId);
695+
}
696+
697+
if (readTime != null) {
698+
request.setReadTime(readTime);
699+
}
677700

678701
pipelineInternalStream(
679-
request,
702+
request.build(),
680703
new PipelineResultObserver() {
681704
@Override
682705
public void onCompleted() {

google-cloud-firestore/src/main/java/com/google/cloud/firestore/ReadTimeTransaction.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ public ApiFuture<AggregateQuerySnapshot> get(@Nonnull AggregateQuery query) {
125125
}
126126
}
127127

128+
@Nonnull
129+
@Override
130+
public ApiFuture<List<PipelineResult>> execute(@Nonnull Pipeline pipeline) {
131+
try (TraceUtil.Scope ignored = transactionTraceContext.makeCurrent()) {
132+
return pipeline.execute(null, readTime);
133+
}
134+
}
135+
128136
@Nonnull
129137
@Override
130138
public Transaction create(

google-cloud-firestore/src/main/java/com/google/cloud/firestore/ServerSideTransaction.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,4 +260,12 @@ public ApiFuture<AggregateQuerySnapshot> get(@Nonnull AggregateQuery query) {
260260
return query.get(transactionId, null);
261261
}
262262
}
263+
264+
@Nonnull
265+
@Override
266+
public ApiFuture<List<PipelineResult>> execute(@Nonnull Pipeline pipeline) {
267+
try (TraceUtil.Scope ignored = transactionTraceContext.makeCurrent()) {
268+
return pipeline.execute(transactionId, null);
269+
}
270+
}
263271
}

google-cloud-firestore/src/main/java/com/google/cloud/firestore/Transaction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.cloud.firestore;
1818

1919
import com.google.api.core.ApiFuture;
20+
import com.google.api.core.BetaApi;
2021
import com.google.api.core.InternalExtensionOnly;
2122
import com.google.cloud.firestore.telemetry.TraceUtil;
2223
import com.google.cloud.firestore.telemetry.TraceUtil.Context;
@@ -134,4 +135,9 @@ public abstract ApiFuture<List<DocumentSnapshot>> getAll(
134135
*/
135136
@Nonnull
136137
public abstract ApiFuture<AggregateQuerySnapshot> get(@Nonnull AggregateQuery query);
138+
139+
/** @return The result of the aggregation. */
140+
@Nonnull
141+
@BetaApi
142+
public abstract ApiFuture<List<PipelineResult>> execute(@Nonnull Pipeline pipeline);
137143
}

google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITPipelineTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
import com.google.cloud.firestore.CollectionReference;
4545
import com.google.cloud.firestore.LocalFirestoreHelper;
46+
import com.google.cloud.firestore.Pipeline;
4647
import com.google.cloud.firestore.PipelineResult;
4748
import com.google.cloud.firestore.pipeline.expressions.Constant;
4849
import com.google.cloud.firestore.pipeline.expressions.Field;
@@ -1013,4 +1014,35 @@ public void testNestedFields() throws Exception {
10131014
map("title", "The Hitchhiker's Guide to the Galaxy", "awards.hugo", true),
10141015
map("title", "Dune", "awards.hugo", true)));
10151016
}
1017+
1018+
@Test
1019+
public void testPipelineInTransactions() throws Exception {
1020+
Pipeline pipeline =
1021+
collection
1022+
.pipeline()
1023+
.where(eq("awards.hugo", true))
1024+
.select("title", "awards.hugo", Field.DOCUMENT_ID);
1025+
1026+
firestore
1027+
.runTransaction(
1028+
transaction -> {
1029+
List<PipelineResult> results = transaction.execute(pipeline).get();
1030+
1031+
assertThat(data(results))
1032+
.isEqualTo(
1033+
Lists.newArrayList(
1034+
map("title", "The Hitchhiker's Guide to the Galaxy", "awards.hugo", true),
1035+
map("title", "Dune", "awards.hugo", true)));
1036+
1037+
transaction.update(collection.document("book1"), map("foo", "bar"));
1038+
1039+
return "done";
1040+
})
1041+
.get();
1042+
1043+
List<PipelineResult> result =
1044+
collection.pipeline().where(eq("foo", "bar")).select("title").execute().get();
1045+
assertThat(data(result))
1046+
.isEqualTo(Lists.newArrayList(map("title", "The Hitchhiker's Guide to the Galaxy")));
1047+
}
10161048
}

0 commit comments

Comments
 (0)