Skip to content

Commit e50a648

Browse files
ahmedabu98 and VardhanThigle
authored and committed
[Managed Iceberg] custom equals method for SerializedDataFile (apache#33554)
* add better equals method for SerializedDataFile
* add hashcode impl
* spotless
* add test to check newly added fields; simplify hashcode
1 parent d9928d4 commit e50a648

File tree

3 files changed

+168
-3
lines changed

3 files changed

+168
-3
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run.",
3-
"modification": 3
3+
"modification": 2
44
}

sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@
2121

2222
import com.google.auto.value.AutoValue;
2323
import java.nio.ByteBuffer;
24+
import java.util.Arrays;
2425
import java.util.HashMap;
2526
import java.util.List;
2627
import java.util.Map;
28+
import java.util.Objects;
2729
import org.apache.beam.sdk.schemas.AutoValueSchema;
2830
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
31+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Equivalence;
32+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
2933
import org.apache.iceberg.DataFile;
3034
import org.apache.iceberg.DataFiles;
3135
import org.apache.iceberg.FileFormat;
@@ -41,8 +45,11 @@
4145
* encode/decode it. This class is an identical version that can be used as a PCollection element
4246
* type.
4347
*
44-
* <p>Use {@link #from(DataFile, PartitionKey)} to create a {@link SerializableDataFile} and {@link
45-
* #createDataFile(PartitionSpec)} to reconstruct the original {@link DataFile}.
48+
* <p>NOTE: If you add any new fields here, you need to also update the {@link #equals} and {@link
49+
* #hashCode()} methods.
50+
*
51+
* <p>Use {@link #from(DataFile, String)} to create a {@link SerializableDataFile} and {@link
52+
* #createDataFile(Map)} to reconstruct the original {@link DataFile}.
4653
*/
4754
@DefaultSchema(AutoValueSchema.class)
4855
@AutoValue
@@ -199,4 +206,86 @@ DataFile createDataFile(Map<Integer, PartitionSpec> partitionSpecs) {
199206
}
200207
return output;
201208
}
209+
210+
@Override
211+
public final boolean equals(@Nullable Object o) {
212+
if (this == o) {
213+
return true;
214+
}
215+
if (o == null || getClass() != o.getClass()) {
216+
return false;
217+
}
218+
SerializableDataFile that = (SerializableDataFile) o;
219+
return getPath().equals(that.getPath())
220+
&& getFileFormat().equals(that.getFileFormat())
221+
&& getRecordCount() == that.getRecordCount()
222+
&& getFileSizeInBytes() == that.getFileSizeInBytes()
223+
&& getPartitionPath().equals(that.getPartitionPath())
224+
&& getPartitionSpecId() == that.getPartitionSpecId()
225+
&& Objects.equals(getKeyMetadata(), that.getKeyMetadata())
226+
&& Objects.equals(getSplitOffsets(), that.getSplitOffsets())
227+
&& Objects.equals(getColumnSizes(), that.getColumnSizes())
228+
&& Objects.equals(getValueCounts(), that.getValueCounts())
229+
&& Objects.equals(getNullValueCounts(), that.getNullValueCounts())
230+
&& Objects.equals(getNanValueCounts(), that.getNanValueCounts())
231+
&& mapEquals(getLowerBounds(), that.getLowerBounds())
232+
&& mapEquals(getUpperBounds(), that.getUpperBounds());
233+
}
234+
235+
private static boolean mapEquals(
236+
@Nullable Map<Integer, byte[]> map1, @Nullable Map<Integer, byte[]> map2) {
237+
if (map1 == null && map2 == null) {
238+
return true;
239+
} else if (map1 == null || map2 == null) {
240+
return false;
241+
}
242+
Equivalence<byte[]> byteArrayEquivalence =
243+
new Equivalence<byte[]>() {
244+
@Override
245+
protected boolean doEquivalent(byte[] a, byte[] b) {
246+
return Arrays.equals(a, b);
247+
}
248+
249+
@Override
250+
protected int doHash(byte[] bytes) {
251+
return Arrays.hashCode(bytes);
252+
}
253+
};
254+
255+
return Maps.difference(map1, map2, byteArrayEquivalence).areEqual();
256+
}
257+
258+
@Override
259+
public final int hashCode() {
260+
int hashCode =
261+
Objects.hash(
262+
getPath(),
263+
getFileFormat(),
264+
getRecordCount(),
265+
getFileSizeInBytes(),
266+
getPartitionPath(),
267+
getPartitionSpecId(),
268+
getKeyMetadata(),
269+
getSplitOffsets(),
270+
getColumnSizes(),
271+
getValueCounts(),
272+
getNullValueCounts(),
273+
getNanValueCounts());
274+
hashCode = 31 * hashCode + computeMapByteHashCode(getLowerBounds());
275+
hashCode = 31 * hashCode + computeMapByteHashCode(getUpperBounds());
276+
return hashCode;
277+
}
278+
279+
private static int computeMapByteHashCode(@Nullable Map<Integer, byte[]> map) {
280+
if (map == null) {
281+
return 0;
282+
}
283+
int hashCode = 0;
284+
for (Map.Entry<Integer, byte[]> entry : map.entrySet()) {
285+
int keyHash = entry.getKey().hashCode();
286+
int valueHash = Arrays.hashCode(entry.getValue()); // content-based hash code
287+
hashCode += keyHash ^ valueHash;
288+
}
289+
return hashCode;
290+
}
202291
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.io.iceberg;
19+
20+
import java.lang.reflect.Method;
21+
import java.util.ArrayList;
22+
import java.util.Arrays;
23+
import java.util.List;
24+
import java.util.Set;
25+
import java.util.stream.Collectors;
26+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
27+
import org.junit.Test;
28+
29+
/**
30+
* Test for {@link SerializableDataFile}. More tests can be found in {@link
31+
* org.apache.beam.sdk.io.iceberg.RecordWriterManagerTest}.
32+
*/
33+
public class SerializableDataFileTest {
34+
static final Set<String> FIELDS_SET =
35+
ImmutableSet.<String>builder()
36+
.add("path")
37+
.add("fileFormat")
38+
.add("recordCount")
39+
.add("fileSizeInBytes")
40+
.add("partitionPath")
41+
.add("partitionSpecId")
42+
.add("keyMetadata")
43+
.add("splitOffsets")
44+
.add("columnSizes")
45+
.add("valueCounts")
46+
.add("nullValueCounts")
47+
.add("nanValueCounts")
48+
.add("lowerBounds")
49+
.add("upperBounds")
50+
.build();
51+
52+
@Test
53+
public void testFieldsInEqualsMethodInSyncWithGetterFields() {
54+
List<String> getMethodNames =
55+
Arrays.stream(SerializableDataFile.class.getDeclaredMethods())
56+
.map(Method::getName)
57+
.filter(methodName -> methodName.startsWith("get"))
58+
.collect(Collectors.toList());
59+
60+
List<String> lowerCaseFields =
61+
FIELDS_SET.stream().map(String::toLowerCase).collect(Collectors.toList());
62+
List<String> extras = new ArrayList<>();
63+
for (String field : getMethodNames) {
64+
if (!lowerCaseFields.contains(field.substring(3).toLowerCase())) {
65+
extras.add(field);
66+
}
67+
}
68+
if (!extras.isEmpty()) {
69+
throw new IllegalStateException(
70+
"Detected new field(s) added to SerializableDataFile: "
71+
+ extras
72+
+ "\nPlease include the new field(s) in SerializableDataFile's equals() and hashCode() methods, then add them "
73+
+ "to this test class's FIELDS_SET.");
74+
}
75+
}
76+
}

0 commit comments

Comments
 (0)