Skip to content

Commit 7e576b1

Browse files
Implement equals/hashCode for GraphqlErrorImpl
We have a great many tests that verify that the errors emitted from a `DataFetcher` fit a certain shape. However, verifying equality of these errors is fiddly and error-prone, as we must directly check each individual attribute on every error - this is painful especially when we are trying to perform assertions on a `List` of `GraphQLError`s. To this end, we add `#equals` / `#hashCode` methods on `GraphqlErrorImpl`. However, it is important to note that `equals` will return `true` if all the attributes match, even if the implementing class is _not_ a `GraphqlErrorImpl`. This is done so that different implementations may be swapped in and out with causing test failures.
1 parent d6dbf61 commit 7e576b1

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

src/main/java/graphql/GraphqlErrorBuilder.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.ArrayList;
1010
import java.util.List;
1111
import java.util.Map;
12+
import java.util.Objects;
1213

1314
import static graphql.Assert.assertNotNull;
1415

@@ -132,6 +133,13 @@ public GraphQLError build() {
132133
return new GraphqlErrorImpl(message, locations, errorType, path, extensions);
133134
}
134135

136+
/**
137+
* A simple implementation of a {@link GraphQLError}.
138+
* <p>
139+
* This provides {@link #hashCode()} and {@link #equals(Object)} methods that afford comparison with other
140+
* {@link GraphQLError} implementations. However, the values in the {@link #getExtensions()} {@link Map} <b>must</b>
141+
* in turn implement {@code hashCode()} and {@link #equals(Object)} for this to function correctly.
142+
*/
135143
private static class GraphqlErrorImpl implements GraphQLError {
136144
private final String message;
137145
private final List<SourceLocation> locations;
@@ -176,6 +184,28 @@ public Map<String, Object> getExtensions() {
176184
public String toString() {
177185
return message;
178186
}
187+
188+
@Override
189+
public boolean equals(Object o) {
190+
if (this == o) return true;
191+
if (!(o instanceof GraphQLError)) return false;
192+
GraphQLError that = (GraphQLError) o;
193+
return Objects.equals(getMessage(), that.getMessage())
194+
&& Objects.equals(getLocations(), that.getLocations())
195+
&& Objects.equals(getErrorType(), that.getErrorType())
196+
&& Objects.equals(getPath(), that.getPath())
197+
&& Objects.equals(getExtensions(), that.getExtensions());
198+
}
199+
200+
@Override
201+
public int hashCode() {
202+
return Objects.hash(
203+
getMessage(),
204+
getLocations(),
205+
getErrorType(),
206+
getPath(),
207+
getExtensions());
208+
}
179209
}
180210

181211
/**

src/test/groovy/graphql/GraphqlErrorBuilderTest.groovy

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,50 @@ class GraphqlErrorBuilderTest extends Specification {
152152
error.path == null
153153

154154
}
155-
}
155+
156+
def "implements equals correctly"() {
157+
when:
158+
def error1 = GraphQLError.newError().message("msg")
159+
.locations(null)
160+
.extensions([x : "y"])
161+
.path(null)
162+
.build()
163+
def error2 = GraphQLError.newError().message("msg")
164+
.locations(null)
165+
.extensions([x : "y"])
166+
.path(null)
167+
.build()
168+
def error3 = GraphQLError.newError().message("msg3")
169+
.locations(null)
170+
.extensions([x : "y"])
171+
.path(null)
172+
.build()
173+
then:
174+
error1 == error2
175+
error1 != error3
176+
error2 != error3
177+
}
178+
179+
def "implements hashCode correctly"() {
180+
when:
181+
def error1 = GraphQLError.newError().message("msg")
182+
.locations(null)
183+
.extensions([x : "y"])
184+
.path(null)
185+
.build()
186+
def error2 = GraphQLError.newError().message("msg")
187+
.locations(null)
188+
.extensions([x : "y"])
189+
.path(null)
190+
.build()
191+
def error3 = GraphQLError.newError().message("msg3")
192+
.locations(null)
193+
.extensions([x : "y"])
194+
.path(null)
195+
.build()
196+
def errors = [error1, error2, error3] as Set
197+
then:
198+
errors == [error1, error3] as Set
199+
errors == [error2, error3] as Set
200+
}
201+
}

0 commit comments

Comments
 (0)