@@ -13,6 +13,7 @@ module BatchWriteItemTransform {
1313 import Seq
1414 import SortedSets
1515 import Util = DynamoDbEncryptionUtil
16+ import Types = AwsCryptographyDbEncryptionSdkDynamoDbTypes
1617
1718 method {:vcs_split_on_every_assert} Input (config: Config , input: BatchWriteItemInputTransformInput )
1819 returns (output: Result< BatchWriteItemInputTransformOutput, Error> )
@@ -80,11 +81,86 @@ module BatchWriteItemTransform {
8081 return Success (BatchWriteItemInputTransformOutput(transformedInput := input.sdkInput.(RequestItems := result)));
8182 }
8283
84+ method GetOrigItem (
85+ tableConfig : ValidTableConfig ,
86+ srcRequests : DDB .WriteRequests,
87+ itemReq : DDB .WriteRequest
88+ ) returns (ret : Result< DDB. WriteRequest, Error> )
89+ requires itemReq. PutRequest. Some?
90+ ensures ret. Success? ==> ret. value. PutRequest. Some?
91+ {
92+ var partKey := tableConfig. partitionKeyName;
93+ var sortKey := tableConfig. sortKeyName;
94+ var item := itemReq. PutRequest. value. Item;
95+ :- Need (partKey in item, E("Required partition key not in unprocessed item"));
96+ :- Need (sortKey.None? || sortKey.value in item, E("Required sort key not in unprocessed item"));
97+ for i := 0 to |srcRequests| {
98+ if srcRequests[i]. PutRequest. Some? {
99+ var req := srcRequests[i]. PutRequest. value. Item;
100+ if partKey in req && req[partKey] == item[partKey] {
101+ if sortKey. None? || (sortKey. value in req && req[sortKey. value] == item[sortKey. value]) {
102+ return Success (srcRequests[i]);
103+ }
104+ }
105+ }
106+ }
107+ return Failure (E("Item in UnprocessedItems not found in original request."));
108+ }
109+
83110 method Output (config: Config , input: BatchWriteItemOutputTransformInput )
84111 returns (output: Result< BatchWriteItemOutputTransformOutput, Error> )
85- ensures output. Success? && output. value. transformedOutput == input. sdkOutput
86112 {
87- return Success (BatchWriteItemOutputTransformOutput(transformedOutput := input.sdkOutput));
113+ if input. sdkOutput. UnprocessedItems. None? {
114+ return Success (BatchWriteItemOutputTransformOutput(transformedOutput := input.sdkOutput));
115+ }
116+
117+ var srcItems := input. originalInput. RequestItems;
118+ var unprocessed := input. sdkOutput. UnprocessedItems. value;
119+ var tableNames := unprocessed. Keys;
120+ var result : map < DDB. TableArn, DDB. WriteRequests> := map [];
121+ var tableNamesSeq := SortedSets. ComputeSetToSequence (tableNames);
122+ ghost var tableNamesSet' := tableNames;
123+ var i := 0;
124+ while i < |tableNamesSeq|
125+ invariant Seq. HasNoDuplicates (tableNamesSeq)
126+ invariant forall j | i <= j < |tableNamesSeq| :: tableNamesSeq[j] in tableNamesSet'
127+ invariant |tableNamesSet'| == |tableNamesSeq| - i
128+ invariant tableNamesSet' <= unprocessed. Keys
129+ {
130+ var tableName := tableNamesSeq[i];
131+
132+ var writeRequests : DDB. WriteRequests := unprocessed[tableName];
133+ if ! IsPlainWrite (config, tableName) {
134+ if tableName ! in srcItems {
135+ return Failure (E(tableName + " in UnprocessedItems for BatchWriteItem response, but not in original request."));
136+ }
137+ var srcRequests := srcItems[tableName];
138+ var tableConfig := config. tableEncryptionConfigs[tableName];
139+ var origItems : seq < DDB. WriteRequest> := [];
140+ for x := 0 to |writeRequests|
141+ invariant |origItems| == x
142+ {
143+ var req : DDB. WriteRequest := writeRequests[x];
144+ if req. PutRequest. None? {
145+ // We only transform PutRequests, so no PutRequest ==> no change
146+ origItems := origItems + [req];
147+ } else {
148+ var orig_item :- GetOrigItem (tableConfig, srcRequests, req);
149+ origItems := origItems + [orig_item];
150+ }
151+ }
152+ writeRequests := origItems;
153+ }
154+ tableNamesSet' := tableNamesSet' - {tableName};
155+ i := i + 1;
156+ assert forall j | i <= j < |tableNamesSeq| :: tableNamesSeq[j] in tableNamesSet' by {
157+ reveal Seq. HasNoDuplicates ();
158+ }
159+ result := result[tableName := writeRequests];
160+ }
161+ :- Need (|result| == |unprocessed|, E("Internal Error")); // Dafny gets too confused
162+ var new_output := input. sdkOutput. (UnprocessedItems := Some (result));
163+ return Success (BatchWriteItemOutputTransformOutput(transformedOutput := new_output));
88164 }
89165
90166}
0 commit comments