1717 */
1818package org .apache .beam .runners .dataflow .worker .windmill .state ;
1919
20+ import com .google .auto .value .AutoValue ;
2021import java .io .Closeable ;
22+ import java .util .HashMap ;
2123import java .util .Optional ;
2224import javax .annotation .Nullable ;
2325import org .apache .beam .runners .core .StateNamespace ;
24- import org .apache .beam .runners .core .StateTable ;
2526import org .apache .beam .runners .core .StateTag ;
2627import org .apache .beam .runners .core .StateTags ;
2728import org .apache .beam .runners .dataflow .worker .util .common .worker .InternedByteString ;
2829import org .apache .beam .runners .dataflow .worker .windmill .state .WindmillStateCache .ForKeyAndFamily ;
2930import org .apache .beam .sdk .coders .BooleanCoder ;
3031import org .apache .beam .sdk .coders .Coder ;
3132import org .apache .beam .sdk .state .*;
32- import org .apache .beam .sdk .transforms .Combine ;
33- import org .apache .beam .sdk .transforms .CombineWithContext ;
33+ import org .apache .beam .sdk .transforms .Combine . CombineFn ;
34+ import org .apache .beam .sdk .transforms .CombineWithContext . CombineFnWithContext ;
3435import org .apache .beam .sdk .transforms .windowing .TimestampCombiner ;
3536import org .apache .beam .sdk .util .CombineFnUtil ;
3637import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Preconditions ;
3738import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Supplier ;
3839
39- final class CachingStateTable extends StateTable {
40+ final class CachingStateTable {
4041
42+ private final HashMap <StateTableKey , WindmillState > stateTable ;
4143 private final String stateFamily ;
4244 private final WindmillStateReader reader ;
4345 private final WindmillStateCache .ForKeyAndFamily cache ;
4446 private final boolean isSystemTable ;
4547 private final Supplier <Closeable > scopedReadStateSupplier ;
46- private final @ Nullable StateTable derivedStateTable ;
48+ private final @ Nullable CachingStateTable derivedStateTable ;
4749 private final boolean isNewKey ;
4850 private final boolean mapStateViaMultimapState ;
4951 private final WindmillStateTagUtil windmillStateTagUtil ;
5052
5153 private CachingStateTable (Builder builder ) {
54+ this .stateTable = new HashMap <>();
5255 this .stateFamily = builder .stateFamily ;
5356 this .reader = builder .reader ;
5457 this .cache = builder .cache ;
@@ -65,20 +68,45 @@ private CachingStateTable(Builder builder) {
6568 }
6669 }
6770
68- static CachingStateTable . Builder builder (
71+ static Builder builder (
6972 String stateFamily ,
7073 WindmillStateReader reader ,
7174 ForKeyAndFamily cache ,
7275 boolean isNewKey ,
7376 Supplier <Closeable > scopedReadStateSupplier ,
7477 WindmillStateTagUtil windmillStateTagUtil ) {
75- return new CachingStateTable . Builder (
78+ return new Builder (
7679 stateFamily , reader , cache , scopedReadStateSupplier , isNewKey , windmillStateTagUtil );
7780 }
7881
79- @ Override
82+ /**
83+ * Gets the {@link State} in the specified {@link StateNamespace} with the specified {@link
84+ * StateTag}, binding it using the {@link #binderForNamespace} if it is not already present in
85+ * this {@link CachingStateTable}.
86+ */
87+ public <StateT extends State > StateT get (
88+ StateNamespace namespace , StateTag <StateT > tag , StateContext <?> c ) {
89+
90+ StateTableKey stateTableKey = StateTableKey .create (namespace , tag );
91+ @ SuppressWarnings ("unchecked" )
92+ StateT storage =
93+ (StateT )
94+ stateTable .computeIfAbsent (
95+ stateTableKey ,
96+ unusedKey -> (WindmillState ) tag .bind (binderForNamespace (namespace , c )));
97+ return storage ;
98+ }
99+
100+ public void clear () {
101+ stateTable .clear ();
102+ }
103+
104+ public Iterable <WindmillState > values () {
105+ return stateTable .values ();
106+ }
107+
80108 @ SuppressWarnings ("deprecation" )
81- protected StateTag .StateBinder binderForNamespace (StateNamespace namespace , StateContext <?> c ) {
109+ private StateTag .StateBinder binderForNamespace (StateNamespace namespace , StateContext <?> c ) {
82110 // Look up state objects in the cache or create new ones if not found. The state will
83111 // be added to the cache in persist().
84112 return new StateTag .StateBinder () {
@@ -190,7 +218,7 @@ public WatermarkHoldState bindWatermark(
190218 public <InputT , AccumT , OutputT > CombiningState <InputT , AccumT , OutputT > bindCombiningValue (
191219 StateTag <CombiningState <InputT , AccumT , OutputT >> address ,
192220 Coder <AccumT > accumCoder ,
193- Combine . CombineFn <InputT , AccumT , OutputT > combineFn ) {
221+ CombineFn <InputT , AccumT , OutputT > combineFn ) {
194222 StateTag <CombiningState <InputT , AccumT , OutputT >> addressOrInternalTag =
195223 addressOrInternalTag (address );
196224
@@ -214,7 +242,7 @@ public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCom
214242 CombiningState <InputT , AccumT , OutputT > bindCombiningValueWithContext (
215243 StateTag <CombiningState <InputT , AccumT , OutputT >> address ,
216244 Coder <AccumT > accumCoder ,
217- CombineWithContext . CombineFnWithContext <InputT , AccumT , OutputT > combineFn ) {
245+ CombineFnWithContext <InputT , AccumT , OutputT > combineFn ) {
218246 return bindCombiningValue (
219247 addressOrInternalTag (address ), accumCoder , CombineFnUtil .bindContext (combineFn , c ));
220248 }
@@ -239,6 +267,21 @@ private <T extends State> StateTag<T> addressOrInternalTag(StateTag<T> address)
239267 };
240268 }
241269
270+ @ AutoValue
271+ abstract static class StateTableKey {
272+
273+ public abstract StateNamespace getStateNamespace ();
274+
275+ public abstract String getId ();
276+
277+ public static StateTableKey create (StateNamespace namespace , StateTag <?> stateTag ) {
278+ // TODO(https://github.com/apache/beam/issues/36753): stateTag.getId() returns only the
279+ // string tag without system/user prefix. This could cause a collision between system and
280+ // user tag with the same id. Consider adding the prefix to state table key.
281+ return new AutoValue_CachingStateTable_StateTableKey (namespace , stateTag .getId ());
282+ }
283+ }
284+
242285 static class Builder {
243286
244287 private final String stateFamily ;
@@ -248,7 +291,7 @@ static class Builder {
248291 private final boolean isNewKey ;
249292 private final WindmillStateTagUtil windmillStateTagUtil ;
250293 private boolean isSystemTable ;
251- private @ Nullable StateTable derivedStateTable ;
294+ private @ Nullable CachingStateTable derivedStateTable ;
252295 private boolean mapStateViaMultimapState = false ;
253296
254297 private Builder (
@@ -268,7 +311,7 @@ private Builder(
268311 this .windmillStateTagUtil = windmillStateTagUtil ;
269312 }
270313
271- Builder withDerivedState (StateTable derivedStateTable ) {
314+ Builder withDerivedState (CachingStateTable derivedStateTable ) {
272315 this .isSystemTable = false ;
273316 this .derivedStateTable = derivedStateTable ;
274317 return this ;
0 commit comments