diff --git a/src/main/java/org/dataloader/DataLoaderRegistry.java b/src/main/java/org/dataloader/DataLoaderRegistry.java index 6bc79f6..7a0e0e6 100644 --- a/src/main/java/org/dataloader/DataLoaderRegistry.java +++ b/src/main/java/org/dataloader/DataLoaderRegistry.java @@ -1,6 +1,7 @@ package org.dataloader; import org.dataloader.annotations.PublicApi; +import org.dataloader.errors.StrictModeRegistryException; import org.dataloader.impl.Assertions; import org.dataloader.instrumentation.ChainedDataLoaderInstrumentation; import org.dataloader.instrumentation.DataLoaderInstrumentation; @@ -20,6 +21,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import static java.lang.String.format; import static org.dataloader.impl.Assertions.assertState; /** @@ -43,21 +45,28 @@ @PublicApi @NullMarked public class DataLoaderRegistry { + protected final Map> dataLoaders; protected final @Nullable DataLoaderInstrumentation instrumentation; + protected final boolean strictMode; public DataLoaderRegistry() { - this(new ConcurrentHashMap<>(), null); + this(new ConcurrentHashMap<>(), null, false); } private DataLoaderRegistry(Builder builder) { - this(builder.dataLoaders, builder.instrumentation); + this(builder.dataLoaders, builder.instrumentation, builder.strictMode); } - protected DataLoaderRegistry(Map> dataLoaders, @Nullable DataLoaderInstrumentation instrumentation) { + protected DataLoaderRegistry( + Map> dataLoaders, + @Nullable DataLoaderInstrumentation instrumentation, + boolean strictMode + ) { this.dataLoaders = instrumentDLs(dataLoaders, instrumentation); this.instrumentation = instrumentation; + this.strictMode = strictMode; } private Map> instrumentDLs(Map> incomingDataLoaders, @Nullable DataLoaderInstrumentation registryInstrumentation) { @@ -76,7 +85,10 @@ protected DataLoaderRegistry(Map> dataLoaders, @Nullabl * @param existingDL the existing data loader * @return a new {@link DataLoader} or the same one if there is nothing to change */ - private static DataLoader nameAndInstrumentDL(String key, @Nullable DataLoaderInstrumentation registryInstrumentation, DataLoader existingDL) { + private DataLoader nameAndInstrumentDL(String key, @Nullable DataLoaderInstrumentation registryInstrumentation, DataLoader existingDL) { + if (strictMode) { + assertKeyStrictly(key); + } existingDL = checkAndSetName(key, existingDL); if (registryInstrumentation == null) { @@ -214,7 +226,9 @@ public DataLoader computeIfAbsent(final String key, * @return a new combined registry */ public DataLoaderRegistry combine(DataLoaderRegistry registry) { - DataLoaderRegistry combined = new DataLoaderRegistry(); + DataLoaderRegistry combined = new Builder() + .strictMode(strictMode) + .build(); this.dataLoaders.forEach(combined::register); registry.dataLoaders.forEach(combined::register); @@ -312,6 +326,12 @@ public Statistics getStatistics() { return stats; } + protected void assertKeyStrictly(String key) { + if (dataLoaders.containsKey(key)) { + throw new StrictModeRegistryException(format("The key %s already has a DataLoader defined", key)); + } + } + /** * @return A builder of {@link DataLoaderRegistry}s */ @@ -323,6 +343,19 @@ public static class Builder { private final Map> dataLoaders = new HashMap<>(); private @Nullable DataLoaderInstrumentation instrumentation; + private boolean strictMode; + + /** + * This puts the builder into strict mode, so if things get defined twice, for example, it + * will throw a {@link org.dataloader.errors.StrictModeRegistryException}. + * + * @param strictMode whether strict mode is enabled + * @return this builder + */ + public Builder strictMode(boolean strictMode) { + this.strictMode = strictMode; + return this; + } /** * This will register a new dataloader @@ -332,6 +365,9 @@ public static class Builder { * @return this builder for a fluent pattern */ public Builder register(String key, DataLoader dataLoader) { + if (strictMode) { + assertKeyStrictly(key); + } dataLoaders.put(key, dataLoader); return this; } @@ -344,6 +380,11 @@ public Builder register(String key, DataLoader dataLoader) { * @return this builder for a fluent pattern */ public Builder registerAll(DataLoaderRegistry otherRegistry) { + if (strictMode) { + otherRegistry.dataLoaders.forEach((key, dataLoader) -> { + assertKeyStrictly(key); + }); + } dataLoaders.putAll(otherRegistry.dataLoaders); return this; } @@ -353,6 +394,12 @@ public Builder instrumentation(DataLoaderInstrumentation instrumentation) { return this; } + private void assertKeyStrictly(String key) { + if (dataLoaders.containsKey(key)) { + throw new StrictModeRegistryException(format("The key %s already has a DataLoader defined", key)); + } + } + /** * @return the newly built {@link DataLoaderRegistry} */ diff --git a/src/main/java/org/dataloader/errors/StrictModeRegistryException.java b/src/main/java/org/dataloader/errors/StrictModeRegistryException.java new file mode 100644 index 0000000..a21c99b --- /dev/null +++ b/src/main/java/org/dataloader/errors/StrictModeRegistryException.java @@ -0,0 +1,14 @@ +package org.dataloader.errors; + +import org.dataloader.annotations.PublicApi; + +/** + * An exception that is thrown when {@link org.dataloader.DataLoaderRegistry.Builder#strictMode(boolean)} is true and multiple + * DataLoaders are registered to the same key. + */ +@PublicApi +public class StrictModeRegistryException extends RuntimeException { + public StrictModeRegistryException(String msg) { + super(msg); + } +} diff --git a/src/main/java/org/dataloader/registries/ScheduledDataLoaderRegistry.java b/src/main/java/org/dataloader/registries/ScheduledDataLoaderRegistry.java index 4f62378..7bb9485 100644 --- a/src/main/java/org/dataloader/registries/ScheduledDataLoaderRegistry.java +++ b/src/main/java/org/dataloader/registries/ScheduledDataLoaderRegistry.java @@ -3,6 +3,7 @@ import org.dataloader.DataLoader; import org.dataloader.DataLoaderRegistry; import org.dataloader.annotations.ExperimentalApi; +import org.dataloader.errors.StrictModeRegistryException; import org.dataloader.impl.Assertions; import org.dataloader.instrumentation.DataLoaderInstrumentation; import org.jspecify.annotations.NullMarked; @@ -16,6 +17,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import static java.lang.String.format; import static org.dataloader.impl.Assertions.nonNull; /** @@ -69,7 +71,7 @@ public class ScheduledDataLoaderRegistry extends DataLoaderRegistry implements A private volatile boolean closed; private ScheduledDataLoaderRegistry(Builder builder) { - super(builder.dataLoaders, builder.instrumentation); + super(builder.dataLoaders, builder.instrumentation, builder.strictMode); this.scheduledExecutorService = Assertions.nonNull(builder.scheduledExecutorService); this.defaultExecutorUsed = builder.defaultExecutorUsed; this.schedule = builder.schedule; @@ -120,7 +122,8 @@ public boolean isTickerMode() { */ public ScheduledDataLoaderRegistry combine(DataLoaderRegistry registry) { Builder combinedBuilder = ScheduledDataLoaderRegistry.newScheduledRegistry() - .dispatchPredicate(this.dispatchPredicate); + .dispatchPredicate(this.dispatchPredicate) + .strictMode(this.strictMode); combinedBuilder.registerAll(this); combinedBuilder.registerAll(registry); return combinedBuilder.build(); @@ -166,6 +169,9 @@ public DispatchPredicate getDispatchPredicate() { * @return this registry */ public ScheduledDataLoaderRegistry register(String key, DataLoader dataLoader, DispatchPredicate dispatchPredicate) { + if (strictMode) { + assertKeyStrictly(key); + } dataLoaders.put(key, dataLoader); dataLoaderPredicates.put(dataLoader, dispatchPredicate); return this; @@ -272,6 +278,7 @@ public static class Builder { private Duration schedule = Duration.ofMillis(10); private boolean tickerMode = false; private @Nullable DataLoaderInstrumentation instrumentation; + private boolean strictMode; /** @@ -291,6 +298,18 @@ public Builder schedule(Duration schedule) { return this; } + /** + * This puts the builder into strict mode, so if things get defined twice, for example, it + * will throw a {@link org.dataloader.errors.StrictModeRegistryException}. + * + * @param strictMode whether strict mode is enabled + * @return this builder + */ + public Builder strictMode(boolean strictMode) { + this.strictMode = strictMode; + return this; + } + /** * This will register a new dataloader * @@ -299,6 +318,9 @@ public Builder schedule(Duration schedule) { * @return this builder for a fluent pattern */ public Builder register(String key, DataLoader dataLoader) { + if (strictMode) { + assertKeyStrictly(key); + } dataLoaders.put(key, dataLoader); return this; } @@ -326,7 +348,13 @@ public Builder register(String key, DataLoader dataLoader, DispatchPredica * @return this builder for a fluent pattern */ public Builder registerAll(DataLoaderRegistry otherRegistry) { - dataLoaders.putAll(otherRegistry.getDataLoadersMap()); + Map> otherDataLoaders = otherRegistry.getDataLoadersMap(); + if (strictMode) { + otherDataLoaders.forEach((key, dataLoader) -> { + assertKeyStrictly(key); + }); + } + dataLoaders.putAll(otherDataLoaders); if (otherRegistry instanceof ScheduledDataLoaderRegistry) { ScheduledDataLoaderRegistry other = (ScheduledDataLoaderRegistry) otherRegistry; dataLoaderPredicates.putAll(other.dataLoaderPredicates); @@ -364,6 +392,12 @@ public Builder instrumentation(DataLoaderInstrumentation instrumentation) { return this; } + private void assertKeyStrictly(String key) { + if (dataLoaders.containsKey(key)) { + throw new StrictModeRegistryException(format("The key %s already has a DataLoader defined", key)); + } + } + /** * @return the newly built {@link ScheduledDataLoaderRegistry} */ diff --git a/src/test/java/org/dataloader/DataLoaderRegistryTest.java b/src/test/java/org/dataloader/DataLoaderRegistryTest.java index 89624d7..8138333 100644 --- a/src/test/java/org/dataloader/DataLoaderRegistryTest.java +++ b/src/test/java/org/dataloader/DataLoaderRegistryTest.java @@ -1,5 +1,6 @@ package org.dataloader; +import org.dataloader.errors.StrictModeRegistryException; import org.dataloader.stats.SimpleStatisticsCollector; import org.dataloader.stats.Statistics; import org.junit.jupiter.api.Assertions; @@ -14,6 +15,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertThrows; public class DataLoaderRegistryTest { final BatchLoader identityBatchLoader = CompletableFuture::completedFuture; @@ -219,4 +221,48 @@ public void builder_works() { assertThat(registry.getDataLoader("c"), equalTo(dlC)); } + + @Test + public void strictMode_works() { + + DataLoader dlA = newDataLoader(identityBatchLoader); + DataLoader dlB = newDataLoader(identityBatchLoader); + + assertThrows(StrictModeRegistryException.class, () -> { + DataLoaderRegistry.newRegistry() + .strictMode(true) + .register("a", dlA) + .register("a", dlB) + .build(); + }); + assertThrows(StrictModeRegistryException.class, () -> { + DataLoaderRegistry.newRegistry() + .strictMode(true) + .register("a", dlA) + .registerAll(DataLoaderRegistry.newRegistry() + .register("a", dlB) + .build()) + .build(); + }); + + DataLoaderRegistry registry = DataLoaderRegistry.newRegistry() + .strictMode(true) + .build(); + registry.register("a", dlA); + + assertThrows(StrictModeRegistryException.class, () -> { + registry.register("a", dlB); + }); + assertThrows(StrictModeRegistryException.class, () -> { + registry.register(newDataLoader("a", identityBatchLoader)); + }); + assertThrows(StrictModeRegistryException.class, () -> { + registry.registerAndGet("a", dlB); + }); + assertThrows(StrictModeRegistryException.class, () -> { + registry.combine(DataLoaderRegistry.newRegistry() + .register("a", dlB) + .build()); + }); + } } diff --git a/src/test/java/org/dataloader/registries/ScheduledDataLoaderRegistryTest.java b/src/test/java/org/dataloader/registries/ScheduledDataLoaderRegistryTest.java index e89939c..873d6bb 100644 --- a/src/test/java/org/dataloader/registries/ScheduledDataLoaderRegistryTest.java +++ b/src/test/java/org/dataloader/registries/ScheduledDataLoaderRegistryTest.java @@ -3,6 +3,7 @@ import org.awaitility.core.ConditionTimeoutException; import org.dataloader.DataLoader; import org.dataloader.DataLoaderRegistry; +import org.dataloader.errors.StrictModeRegistryException; import org.dataloader.fixtures.parameterized.TestDataLoaderFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -22,12 +23,14 @@ import static java.util.Collections.singletonList; import static org.awaitility.Awaitility.await; import static org.awaitility.Duration.TWO_SECONDS; +import static org.dataloader.DataLoaderFactory.newDataLoader; import static org.dataloader.fixtures.TestKit.snooze; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -371,4 +374,48 @@ public void executors_are_shutdown() { } + + @Test + public void strictMode_works() { + + DataLoader dlA = newDataLoader(CompletableFuture::completedFuture); + DataLoader dlB = newDataLoader(CompletableFuture::completedFuture); + + assertThrows(StrictModeRegistryException.class, () -> { + ScheduledDataLoaderRegistry.newRegistry() + .strictMode(true) + .register("a", dlA) + .register("a", dlB) + .build(); + }); + assertThrows(StrictModeRegistryException.class, () -> { + ScheduledDataLoaderRegistry.newRegistry() + .strictMode(true) + .register("a", dlA) + .registerAll(ScheduledDataLoaderRegistry.newRegistry() + .register("a", dlB) + .build()) + .build(); + }); + + DataLoaderRegistry registry = ScheduledDataLoaderRegistry.newRegistry() + .strictMode(true) + .build(); + registry.register("a", dlA); + + assertThrows(StrictModeRegistryException.class, () -> { + registry.register("a", dlB); + }); + assertThrows(StrictModeRegistryException.class, () -> { + registry.register(newDataLoader("a", CompletableFuture::completedFuture)); + }); + assertThrows(StrictModeRegistryException.class, () -> { + registry.registerAndGet("a", dlB); + }); + assertThrows(StrictModeRegistryException.class, () -> { + registry.combine(ScheduledDataLoaderRegistry.newRegistry() + .register("a", dlB) + .build()); + }); + } }