Skip to content

Commit adc5b4b

Browse files
gscheibelbrennantaylor
authored andcommitted
Prevent type from being recreated when mixing interfaces and unions (#62)
1 parent c1249b5 commit adc5b4b

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

src/main/kotlin/com/expedia/graphql/schema/generator/SchemaGenerator.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ internal class SchemaGenerator(
271271
}
272272

273273
private fun interfaceType(kClass: KClass<*>): GraphQLType {
274-
return state.cache.buildIfNotUnderConstruction(kClass) { _ ->
274+
return state.cache.buildIfNotUnderConstruction(kClass) {
275275
val builder = GraphQLInterfaceType.newInterface()
276276

277277
builder.name(kClass.simpleName)
@@ -293,16 +293,19 @@ internal class SchemaGenerator(
293293
val objectType = objectType(it.kotlin, interfaceType)
294294
val key = TypesCacheKey(it.kotlin.createType(), false)
295295

296-
state.additionalTypes.add(objectType)
297296
state.cache.put(key, KGraphQLType(it.kotlin, objectType))
297+
if (objectType !is GraphQLTypeReference) {
298+
state.additionalTypes.add(objectType)
299+
}
300+
state.cache.removeTypeUnderConstruction(it.kotlin)
298301
}
299302

300303
interfaceType
301304
}
302305
}
303306

304307
private fun unionType(kClass: KClass<*>): GraphQLType {
305-
return state.cache.buildIfNotUnderConstruction(kClass) { _ ->
308+
return state.cache.buildIfNotUnderConstruction(kClass) {
306309
val builder = GraphQLUnionType.newUnionType()
307310

308311
builder.name(kClass.simpleName)
@@ -313,7 +316,8 @@ internal class SchemaGenerator(
313316
implementations
314317
.filterNot { it.kotlin.isAbstract }
315318
.forEach {
316-
val objectType = objectType(it.kotlin)
319+
val objectType = state.cache.get(TypesCacheKey(it.kotlin.createType(), false)) ?: objectType(it.kotlin)
320+
317321
val key = TypesCacheKey(it.kotlin.createType(), false)
318322

319323
if (objectType is GraphQLTypeReference) {
@@ -323,6 +327,9 @@ internal class SchemaGenerator(
323327
}
324328

325329
state.cache.put(key, KGraphQLType(it.kotlin, objectType))
330+
if (state.cache.doesNotContain(it.kotlin)) {
331+
state.cache.put(key, KGraphQLType(it.kotlin, objectType))
332+
}
326333
}
327334

328335
builder.build()

src/main/kotlin/com/expedia/graphql/schema/generator/TypesCache.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ internal class TypesCache(private val supportedPackages: List<String>) {
4444
fun doesNotContainGraphQLType(graphQLType: GraphQLType) =
4545
cache.none { (_, v) -> v.graphQLType.name == graphQLType.name }
4646

47+
fun doesNotContain(kClass: KClass<*>): Boolean = cache.none { (_, ktype) -> ktype.kClass == kClass }
48+
4749
@Throws(CouldNotGetNameOfEnumException::class)
4850
private fun getCacheKeyString(cacheKey: TypesCacheKey): String {
4951
val kClass = getKClassFromKType(cacheKey.type)
@@ -91,6 +93,8 @@ internal class TypesCache(private val supportedPackages: List<String>) {
9193

9294
private fun putTypeUnderConstruction(kClass: KClass<*>) = typeUnderConstruction.add(kClass)
9395

96+
fun removeTypeUnderConstruction(kClass: KClass<*>) = typeUnderConstruction.remove(kClass)
97+
9498
fun buildIfNotUnderConstruction(kClass: KClass<*>, build: (KClass<*>) -> GraphQLType): GraphQLType {
9599
return if (typeUnderConstruction.contains(kClass)) {
96100
GraphQLTypeReference.typeRef(kClass.simpleName)

src/test/kotlin/com/expedia/graphql/schema/generator/PolymorphicTests.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class PolymorphicTests {
8989
}
9090

9191
class QueryWithInterface {
92+
fun fromUnion(): Union = AnImplementation()
9293
fun query(): AnInterface = AnImplementation()
9394
fun fromImplementation(): AnImplementation = AnImplementation()
9495
}
@@ -101,14 +102,16 @@ class QueryWithUnAuthorizedUnionArgument {
101102
fun notAllowed(body: BodyPart): BodyPart = body
102103
}
103104

105+
interface Union
106+
104107
interface AnInterface {
105108
val property: String
106109
}
107110

108111
data class AnImplementation(
109112
override val property: String = "A value",
110113
val implementationSpecific: String = "It's implementation specific"
111-
) : AnInterface
114+
) : AnInterface, Union
112115

113116
class QueryWithUnion {
114117
fun query(whichHand: String): BodyPart = when (whichHand) {

0 commit comments

Comments
 (0)