diff --git a/pom.xml b/pom.xml
index 6e2fde6..77e079a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -150,6 +150,13 @@
+
+ io.projectreactor
+ reactor-test
+ test
+
+
+
org.projectlombok
lombok
diff --git a/src/main/java/org/gridsuite/gateway/GatewayConfig.java b/src/main/java/org/gridsuite/gateway/GatewayConfig.java
index 741c8d7..7788284 100644
--- a/src/main/java/org/gridsuite/gateway/GatewayConfig.java
+++ b/src/main/java/org/gridsuite/gateway/GatewayConfig.java
@@ -6,15 +6,15 @@
*/
package org.gridsuite.gateway;
-import org.gridsuite.gateway.endpoints.CgmesGlServer;
-import org.gridsuite.gateway.endpoints.*;
+import org.gridsuite.gateway.endpoints.EndPointServer;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
-import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource;
+import java.util.List;
+
/**
* @author Chamseddine Benhamed
* @author Slimane Amar
@@ -30,36 +30,11 @@ public class GatewayConfig {
public static final String HEADER_CLIENT_ID = "clientId";
@Bean
- public RouteLocator myRoutes(RouteLocatorBuilder builder, ApplicationContext context) {
- return builder.routes()
- .route(p -> context.getBean(StudyServer.class).getRoute(p))
- .route(p -> context.getBean(CaseServer.class).getRoute(p))
- .route(p -> context.getBean(MergeServer.class).getRoute(p))
- .route(p -> context.getBean(StudyNotificationServer.class).getRoute(p))
- .route(p -> context.getBean(MergeNotificationServer.class).getRoute(p))
- .route(p -> context.getBean(DirectoryNotificationServer.class).getRoute(p))
- .route(p -> context.getBean(ContingencyServer.class).getRoute(p))
- .route(p -> context.getBean(ConfigServer.class).getRoute(p))
- .route(p -> context.getBean(ConfigNotificationServer.class).getRoute(p))
- .route(p -> context.getBean(DirectoryServer.class).getRoute(p))
- .route(p -> context.getBean(ExploreServer.class).getRoute(p))
- .route(p -> context.getBean(CgmesBoundaryServer.class).getRoute(p))
- .route(p -> context.getBean(DynamicMappingServer.class).getRoute(p))
- .route(p -> context.getBean(FilterServer.class).getRoute(p))
- .route(p -> context.getBean(ReportServer.class).getRoute(p))
- .route(p -> context.getBean(NetworkModificationServer.class).getRoute(p))
- .route(p -> context.getBean(NetworkConversionServer.class).getRoute(p))
- .route(p -> context.getBean(OdreServer.class).getRoute(p))
- .route(p -> context.getBean(GeoDataServer.class).getRoute(p))
- .route(p -> context.getBean(UserAdminServer.class).getRoute(p))
- .route(p -> context.getBean(CgmesGlServer.class).getRoute(p))
- .route(p -> context.getBean(SensitivityAnalysisServer.class).getRoute(p))
- .route(p -> context.getBean(LoadFlowServer.class).getRoute(p))
- .route(p -> context.getBean(SecurityAnalysisServer.class).getRoute(p))
- .route(p -> context.getBean(DynamicSimulationServer.class).getRoute(p))
- .route(p -> context.getBean(CaseImportServer.class).getRoute(p))
- .route(p -> context.getBean(VoltageInitServer.class).getRoute(p))
- .route(p -> context.getBean(ShortCircuitServer.class).getRoute(p))
- .build();
+ public RouteLocator myRoutes(RouteLocatorBuilder builder, List servers) {
+ final RouteLocatorBuilder.Builder routes = builder.routes();
+ for (final EndPointServer server : servers) {
+ routes.route(server.getClass().getName(), server::getRoute);
+ }
+ return routes.build();
}
}
diff --git a/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java b/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java
index ecf5933..7b19521 100644
--- a/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java
+++ b/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java
@@ -50,11 +50,6 @@ public String getEndpointName() {
return ENDPOINT_NAME;
}
- @Override
- public boolean hasElementsAccessControl() {
- return true;
- }
-
@Override
public Optional getAccessControlInfos(@NonNull ServerHttpRequest request) {
RequestPath path = Objects.requireNonNull(request.getPath());
diff --git a/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java b/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java
index 4103f84..6856ffa 100644
--- a/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java
+++ b/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java
@@ -16,12 +16,12 @@
/**
* @author Slimane Amar
*/
-public abstract class AbstractGlobalPreFilter implements GlobalFilter, Ordered {
+abstract class AbstractGlobalPreFilter implements GlobalFilter, Ordered {
protected Mono completeWithCode(ServerWebExchange exchange, HttpStatus code) {
exchange.getResponse().setStatusCode(code);
if ("websocket".equalsIgnoreCase(exchange.getRequest().getHeaders().getUpgrade())) {
- // Force the connection to close for websockets handshakes to workaround apache
+ // Force the connection to close for websockets handshakes to work around apache
// httpd reusing the connection for all subsequent requests in this connection.
exchange.getResponse().getHeaders().set(HttpHeaders.CONNECTION, "close");
}
diff --git a/src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java b/src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java
new file mode 100644
index 0000000..16733ba
--- /dev/null
+++ b/src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java
@@ -0,0 +1,106 @@
+package org.gridsuite.gateway;
+
+import lombok.extern.slf4j.Slf4j;
+import org.assertj.core.api.InstanceOfAssertFactories;
+import org.assertj.core.api.WithAssertions;
+import org.gridsuite.gateway.endpoints.*;
+import org.gridsuite.gateway.filters.ElementAccessControllerGlobalPreFilter;
+import org.gridsuite.gateway.filters.TokenValidatorGlobalPreFilter;
+import org.gridsuite.gateway.filters.UserAdminControlGlobalPreFilter;
+import org.junit.jupiter.api.Test;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.cloud.gateway.filter.GlobalFilter;
+import org.springframework.cloud.gateway.filter.WebsocketRoutingFilter;
+import org.springframework.cloud.gateway.route.RouteLocator;
+import org.springframework.context.ApplicationContext;
+import org.springframework.core.annotation.AnnotationAwareOrderComparator;
+import reactor.test.StepVerifier;
+
+import java.util.Map;
+
+@Slf4j
+@SpringBootTest
+class GatewayApplicationTest implements WithAssertions {
+ @Autowired
+ private ApplicationContext applicationContext;
+
+ @Autowired
+ private RouteLocator myRoutes;
+
+ @Test
+ void testAllEndpointServersFound() {
+ assertThat(applicationContext.getBeansOfType(EndPointServer.class)).as("found EndPointServer beans")
+ .allSatisfy((name, srv) -> assertThat(name).isEqualTo(srv.getEndpointName()))
+ .extracting(Map::values, InstanceOfAssertFactories.collection(EndPointServer.class)).as("EndPointServer beans")
+ .doesNotHaveDuplicates()
+ .extracting(EndPointServer::getClass).as("EndPointServer classes")
+ .containsExactlyInAnyOrder(
+ CaseImportServer.class,
+ CaseServer.class,
+ CgmesBoundaryServer.class,
+ CgmesGlServer.class,
+ ConfigNotificationServer.class,
+ ConfigServer.class,
+ ContingencyServer.class,
+ DirectoryNotificationServer.class,
+ DirectoryServer.class,
+ DynamicMappingServer.class,
+ DynamicSimulationServer.class,
+ ExploreServer.class,
+ FilterServer.class,
+ GeoDataServer.class,
+ LoadFlowServer.class,
+ MergeNotificationServer.class,
+ MergeServer.class,
+ NetworkConversionServer.class,
+ NetworkModificationServer.class,
+ OdreServer.class,
+ ReportServer.class,
+ SecurityAnalysisServer.class,
+ SensitivityAnalysisServer.class,
+ ShortCircuitServer.class,
+ StudyNotificationServer.class,
+ StudyServer.class,
+ UserAdminServer.class,
+ VoltageInitServer.class
+ );
+ }
+
+ @Test
+ void testRoutesInitialized() {
+ StepVerifier.create(myRoutes.getRoutes())
+ .as("routes found")
+ .expectNextCount(28)
+ .verifyComplete();
+ }
+
+ @Test
+ void testFiltersOrder() {
+ assertThat(applicationContext.getBeansOfType(GlobalFilter.class)
+ .values()
+ .stream()
+ .sorted(AnnotationAwareOrderComparator.INSTANCE) //sort work only on bean instances
+ .peek(f -> log.info("p={} ; o={} ; {}", AAOC.INSTANCE.getPriority(f), AAOC.INSTANCE.getOrder(f), f.getClass().getName()))
+ .map(GlobalFilter::getClass)
+ .toList()).as("global filters found")
+ // Before ElementAccessControllerGlobalPreFilter to enforce authentication
+ .containsSubsequence(TokenValidatorGlobalPreFilter.class, ElementAccessControllerGlobalPreFilter.class)
+ // Before WebsocketRoutingFilter to control access
+ .containsSubsequence(ElementAccessControllerGlobalPreFilter.class, WebsocketRoutingFilter.class)
+ .containsSubsequence(
+ TokenValidatorGlobalPreFilter.class, //Ordered.LOWEST_PRECEDENCE - 4
+ UserAdminControlGlobalPreFilter.class, //Ordered.LOWEST_PRECEDENCE - 3
+ ElementAccessControllerGlobalPreFilter.class //Ordered.LOWEST_PRECEDENCE - 2
+ );
+ }
+
+ private static class AAOC extends AnnotationAwareOrderComparator {
+ public static final AAOC INSTANCE = new AAOC();
+
+ @Override
+ public int getOrder(final Object obj) {
+ return super.getOrder(obj);
+ }
+ }
+}
diff --git a/src/test/java/org/gridsuite/gateway/TokenValidationTest.java b/src/test/java/org/gridsuite/gateway/TokenValidationTest.java
index 693f422..813acff 100644
--- a/src/test/java/org/gridsuite/gateway/TokenValidationTest.java
+++ b/src/test/java/org/gridsuite/gateway/TokenValidationTest.java
@@ -88,7 +88,7 @@ public class TokenValidationTest {
private String expiredToken;
- private String tokenWithNotAllowedissuer;
+ private String tokenWithNotAllowedIssuer;
private RSAKey rsaKey;
@@ -157,7 +157,7 @@ public void prepareToken() throws JOSEException {
token = signedJWT.serialize();
token2 = signedJWT2.serialize();
expiredToken = signedJWTExpired.serialize();
- tokenWithNotAllowedissuer = signedJWTWithIssuerNotAllowed.serialize();
+ tokenWithNotAllowedIssuer = signedJWTWithIssuerNotAllowed.serialize();
}
private void testWebsocket(String name) throws InterruptedException {
@@ -524,7 +524,7 @@ public void invalidToken() {
//test with with not allowed issuer
webClient
.get().uri("case/v1/cases")
- .header("Authorization", "Bearer " + tokenWithNotAllowedissuer)
+ .header("Authorization", "Bearer " + tokenWithNotAllowedIssuer)
.exchange()
.expectStatus().isEqualTo(401);