2424import org .apache .beam .runners .core .StateTable ;
2525import org .apache .beam .runners .core .StateTag ;
2626import org .apache .beam .runners .core .StateTags ;
27+ import org .apache .beam .runners .dataflow .worker .util .common .worker .InternedByteString ;
2728import org .apache .beam .runners .dataflow .worker .windmill .state .WindmillStateCache .ForKeyAndFamily ;
2829import org .apache .beam .sdk .coders .BooleanCoder ;
2930import org .apache .beam .sdk .coders .Coder ;
3637import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Supplier ;
3738
3839final class CachingStateTable extends StateTable {
40+
3941 private final String stateFamily ;
4042 private final WindmillStateReader reader ;
4143 private final WindmillStateCache .ForKeyAndFamily cache ;
@@ -84,23 +86,14 @@ protected StateTag.StateBinder binderForNamespace(StateNamespace namespace, Stat
8486 public <T > BagState <T > bindBag (StateTag <BagState <T >> address , Coder <T > elemCoder ) {
8587 StateTag <BagState <T >> resolvedAddress =
8688 isSystemTable ? StateTags .makeSystemTagInternal (address ) : address ;
89+ InternedByteString encodedKey = windmillStateTagUtil .encodeKey (namespace , resolvedAddress );
8790
88- WindmillBag <T > result =
89- cache
90- .get (namespace , resolvedAddress )
91- .map (bagState -> (WindmillBag <T >) bagState )
92- .orElseGet (
93- () ->
94- new WindmillBag <>(
95- namespace ,
96- resolvedAddress ,
97- stateFamily ,
98- elemCoder ,
99- isNewKey ,
100- windmillStateTagUtil ));
101-
102- result .initializeForWorkItem (reader , scopedReadStateSupplier );
103- return result ;
91+ @ Nullable WindmillBag <T > bag = (WindmillBag <T >) cache .get (namespace , encodedKey );
92+ if (bag == null ) {
93+ bag = new WindmillBag <>(namespace , encodedKey , stateFamily , elemCoder , isNewKey );
94+ }
95+ bag .initializeForWorkItem (reader , scopedReadStateSupplier );
96+ return bag ;
10497 }
10598
10699 @ Override
@@ -123,20 +116,13 @@ public <KeyT, ValueT> AbstractWindmillMap<KeyT, ValueT> bindMap(
123116 new WindmillMapViaMultimap <>(
124117 bindMultimap (internalMultimapAddress , keyCoder , valueCoder ));
125118 } else {
126- result =
127- cache
128- .get (namespace , spec )
129- .map (mapState -> (AbstractWindmillMap <KeyT , ValueT >) mapState )
130- .orElseGet (
131- () ->
132- new WindmillMap <>(
133- namespace ,
134- spec ,
135- stateFamily ,
136- keyCoder ,
137- valueCoder ,
138- isNewKey ,
139- windmillStateTagUtil ));
119+ InternedByteString encodedKey = windmillStateTagUtil .encodeKey (namespace , spec );
120+ result = (AbstractWindmillMap <KeyT , ValueT >) cache .get (namespace , encodedKey );
121+ if (result == null ) {
122+ result =
123+ new WindmillMap <>(
124+ namespace , encodedKey , stateFamily , keyCoder , valueCoder , isNewKey );
125+ }
140126 }
141127 result .initializeForWorkItem (reader , scopedReadStateSupplier );
142128 return result ;
@@ -147,20 +133,14 @@ public <KeyT, ValueT> WindmillMultimap<KeyT, ValueT> bindMultimap(
147133 StateTag <MultimapState <KeyT , ValueT >> spec ,
148134 Coder <KeyT > keyCoder ,
149135 Coder <ValueT > valueCoder ) {
136+ InternedByteString encodedKey = windmillStateTagUtil .encodeKey (namespace , spec );
150137 WindmillMultimap <KeyT , ValueT > result =
151- cache
152- .get (namespace , spec )
153- .map (multimapState -> (WindmillMultimap <KeyT , ValueT >) multimapState )
154- .orElseGet (
155- () ->
156- new WindmillMultimap <>(
157- namespace ,
158- spec ,
159- stateFamily ,
160- keyCoder ,
161- valueCoder ,
162- isNewKey ,
163- windmillStateTagUtil ));
138+ (WindmillMultimap <KeyT , ValueT >) cache .get (namespace , encodedKey );
139+ if (result == null ) {
140+ result =
141+ new WindmillMultimap <>(
142+ namespace , encodedKey , stateFamily , keyCoder , valueCoder , isNewKey );
143+ }
164144 result .initializeForWorkItem (reader , scopedReadStateSupplier );
165145 return result ;
166146 }
@@ -169,21 +149,21 @@ public <KeyT, ValueT> WindmillMultimap<KeyT, ValueT> bindMultimap(
169149 public <T > OrderedListState <T > bindOrderedList (
170150 StateTag <OrderedListState <T >> spec , Coder <T > elemCoder ) {
171151 StateTag <OrderedListState <T >> specOrInternalTag = addressOrInternalTag (spec );
152+ InternedByteString encodedKey =
153+ windmillStateTagUtil .encodeKey (namespace , specOrInternalTag );
172154
173- WindmillOrderedList <T > result =
174- cache
175- .get (namespace , specOrInternalTag )
176- .map (orderedList -> (WindmillOrderedList <T >) orderedList )
177- .orElseGet (
178- () ->
179- new WindmillOrderedList <>(
180- Optional .ofNullable (derivedStateTable ).orElse (CachingStateTable .this ),
181- namespace ,
182- specOrInternalTag ,
183- stateFamily ,
184- elemCoder ,
185- isNewKey ,
186- windmillStateTagUtil ));
155+ WindmillOrderedList <T > result = (WindmillOrderedList <T >) cache .get (namespace , encodedKey );
156+ if (result == null ) {
157+ result =
158+ new WindmillOrderedList <>(
159+ Optional .ofNullable (derivedStateTable ).orElse (CachingStateTable .this ),
160+ namespace ,
161+ encodedKey ,
162+ specOrInternalTag ,
163+ stateFamily ,
164+ elemCoder ,
165+ isNewKey );
166+ }
187167
188168 result .initializeForWorkItem (reader , scopedReadStateSupplier );
189169 return result ;
@@ -193,21 +173,15 @@ public <T> OrderedListState<T> bindOrderedList(
193173 public WatermarkHoldState bindWatermark (
194174 StateTag <WatermarkHoldState > address , TimestampCombiner timestampCombiner ) {
195175 StateTag <WatermarkHoldState > addressOrInternalTag = addressOrInternalTag (address );
176+ InternedByteString encodedKey =
177+ windmillStateTagUtil .encodeKey (namespace , addressOrInternalTag );
196178
197- WindmillWatermarkHold result =
198- cache
199- .get (namespace , addressOrInternalTag )
200- .map (watermarkHold -> (WindmillWatermarkHold ) watermarkHold )
201- .orElseGet (
202- () ->
203- new WindmillWatermarkHold (
204- namespace ,
205- address ,
206- stateFamily ,
207- timestampCombiner ,
208- isNewKey ,
209- windmillStateTagUtil ));
210-
179+ WindmillWatermarkHold result = (WindmillWatermarkHold ) cache .get (namespace , encodedKey );
180+ if (result == null ) {
181+ result =
182+ new WindmillWatermarkHold (
183+ namespace , encodedKey , stateFamily , timestampCombiner , isNewKey );
184+ }
211185 result .initializeForWorkItem (reader , scopedReadStateSupplier );
212186 return result ;
213187 }
@@ -248,21 +222,13 @@ CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext(
248222 @ Override
249223 public <T > ValueState <T > bindValue (StateTag <ValueState <T >> address , Coder <T > coder ) {
250224 StateTag <ValueState <T >> addressOrInternalTag = addressOrInternalTag (address );
225+ InternedByteString encodedKey =
226+ windmillStateTagUtil .encodeKey (namespace , addressOrInternalTag );
251227
252- WindmillValue <T > result =
253- cache
254- .get (namespace , addressOrInternalTag )
255- .map (value -> (WindmillValue <T >) value )
256- .orElseGet (
257- () ->
258- new WindmillValue <>(
259- namespace ,
260- addressOrInternalTag ,
261- stateFamily ,
262- coder ,
263- isNewKey ,
264- windmillStateTagUtil ));
265-
228+ WindmillValue <T > result = (WindmillValue <T >) cache .get (namespace , encodedKey );
229+ if (result == null ) {
230+ result = new WindmillValue <>(namespace , encodedKey , stateFamily , coder , isNewKey );
231+ }
266232 result .initializeForWorkItem (reader , scopedReadStateSupplier );
267233 return result ;
268234 }
@@ -274,6 +240,7 @@ private <T extends State> StateTag<T> addressOrInternalTag(StateTag<T> address)
274240 }
275241
276242 static class Builder {
243+
277244 private final String stateFamily ;
278245 private final WindmillStateReader reader ;
279246 private final WindmillStateCache .ForKeyAndFamily cache ;
0 commit comments