diff --git a/pom.xml b/pom.xml index 6e2fde6..77e079a 100644 --- a/pom.xml +++ b/pom.xml @@ -150,6 +150,13 @@ + + io.projectreactor + reactor-test + test + + + org.projectlombok lombok diff --git a/src/main/java/org/gridsuite/gateway/GatewayConfig.java b/src/main/java/org/gridsuite/gateway/GatewayConfig.java index 741c8d7..7788284 100644 --- a/src/main/java/org/gridsuite/gateway/GatewayConfig.java +++ b/src/main/java/org/gridsuite/gateway/GatewayConfig.java @@ -6,15 +6,15 @@ */ package org.gridsuite.gateway; -import org.gridsuite.gateway.endpoints.CgmesGlServer; -import org.gridsuite.gateway.endpoints.*; +import org.gridsuite.gateway.endpoints.EndPointServer; import org.springframework.cloud.gateway.route.RouteLocator; import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder; -import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.PropertySource; +import java.util.List; + /** * @author Chamseddine Benhamed * @author Slimane Amar @@ -30,36 +30,11 @@ public class GatewayConfig { public static final String HEADER_CLIENT_ID = "clientId"; @Bean - public RouteLocator myRoutes(RouteLocatorBuilder builder, ApplicationContext context) { - return builder.routes() - .route(p -> context.getBean(StudyServer.class).getRoute(p)) - .route(p -> context.getBean(CaseServer.class).getRoute(p)) - .route(p -> context.getBean(MergeServer.class).getRoute(p)) - .route(p -> context.getBean(StudyNotificationServer.class).getRoute(p)) - .route(p -> context.getBean(MergeNotificationServer.class).getRoute(p)) - .route(p -> context.getBean(DirectoryNotificationServer.class).getRoute(p)) - .route(p -> context.getBean(ContingencyServer.class).getRoute(p)) - .route(p -> context.getBean(ConfigServer.class).getRoute(p)) - .route(p -> context.getBean(ConfigNotificationServer.class).getRoute(p)) - .route(p -> context.getBean(DirectoryServer.class).getRoute(p)) - .route(p -> context.getBean(ExploreServer.class).getRoute(p)) - .route(p -> context.getBean(CgmesBoundaryServer.class).getRoute(p)) - .route(p -> context.getBean(DynamicMappingServer.class).getRoute(p)) - .route(p -> context.getBean(FilterServer.class).getRoute(p)) - .route(p -> context.getBean(ReportServer.class).getRoute(p)) - .route(p -> context.getBean(NetworkModificationServer.class).getRoute(p)) - .route(p -> context.getBean(NetworkConversionServer.class).getRoute(p)) - .route(p -> context.getBean(OdreServer.class).getRoute(p)) - .route(p -> context.getBean(GeoDataServer.class).getRoute(p)) - .route(p -> context.getBean(UserAdminServer.class).getRoute(p)) - .route(p -> context.getBean(CgmesGlServer.class).getRoute(p)) - .route(p -> context.getBean(SensitivityAnalysisServer.class).getRoute(p)) - .route(p -> context.getBean(LoadFlowServer.class).getRoute(p)) - .route(p -> context.getBean(SecurityAnalysisServer.class).getRoute(p)) - .route(p -> context.getBean(DynamicSimulationServer.class).getRoute(p)) - .route(p -> context.getBean(CaseImportServer.class).getRoute(p)) - .route(p -> context.getBean(VoltageInitServer.class).getRoute(p)) - .route(p -> context.getBean(ShortCircuitServer.class).getRoute(p)) - .build(); + public RouteLocator myRoutes(RouteLocatorBuilder builder, List servers) { + final RouteLocatorBuilder.Builder routes = builder.routes(); + for (final EndPointServer server : servers) { + routes.route(server.getClass().getName(), server::getRoute); + } + return routes.build(); } } diff --git a/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java b/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java index ecf5933..7b19521 100644 --- a/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java +++ b/src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java @@ -50,11 +50,6 @@ public String getEndpointName() { return ENDPOINT_NAME; } - @Override - public boolean hasElementsAccessControl() { - return true; - } - @Override public Optional getAccessControlInfos(@NonNull ServerHttpRequest request) { RequestPath path = Objects.requireNonNull(request.getPath()); diff --git a/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java b/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java index 4103f84..6856ffa 100644 --- a/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java +++ b/src/main/java/org/gridsuite/gateway/filters/AbstractGlobalPreFilter.java @@ -16,12 +16,12 @@ /** * @author Slimane Amar */ -public abstract class AbstractGlobalPreFilter implements GlobalFilter, Ordered { +abstract class AbstractGlobalPreFilter implements GlobalFilter, Ordered { protected Mono completeWithCode(ServerWebExchange exchange, HttpStatus code) { exchange.getResponse().setStatusCode(code); if ("websocket".equalsIgnoreCase(exchange.getRequest().getHeaders().getUpgrade())) { - // Force the connection to close for websockets handshakes to workaround apache + // Force the connection to close for websockets handshakes to work around apache // httpd reusing the connection for all subsequent requests in this connection. exchange.getResponse().getHeaders().set(HttpHeaders.CONNECTION, "close"); } diff --git a/src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java b/src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java new file mode 100644 index 0000000..16733ba --- /dev/null +++ b/src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java @@ -0,0 +1,106 @@ +package org.gridsuite.gateway; + +import lombok.extern.slf4j.Slf4j; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.WithAssertions; +import org.gridsuite.gateway.endpoints.*; +import org.gridsuite.gateway.filters.ElementAccessControllerGlobalPreFilter; +import org.gridsuite.gateway.filters.TokenValidatorGlobalPreFilter; +import org.gridsuite.gateway.filters.UserAdminControlGlobalPreFilter; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.cloud.gateway.filter.GlobalFilter; +import org.springframework.cloud.gateway.filter.WebsocketRoutingFilter; +import org.springframework.cloud.gateway.route.RouteLocator; +import org.springframework.context.ApplicationContext; +import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import reactor.test.StepVerifier; + +import java.util.Map; + +@Slf4j +@SpringBootTest +class GatewayApplicationTest implements WithAssertions { + @Autowired + private ApplicationContext applicationContext; + + @Autowired + private RouteLocator myRoutes; + + @Test + void testAllEndpointServersFound() { + assertThat(applicationContext.getBeansOfType(EndPointServer.class)).as("found EndPointServer beans") + .allSatisfy((name, srv) -> assertThat(name).isEqualTo(srv.getEndpointName())) + .extracting(Map::values, InstanceOfAssertFactories.collection(EndPointServer.class)).as("EndPointServer beans") + .doesNotHaveDuplicates() + .extracting(EndPointServer::getClass).as("EndPointServer classes") + .containsExactlyInAnyOrder( + CaseImportServer.class, + CaseServer.class, + CgmesBoundaryServer.class, + CgmesGlServer.class, + ConfigNotificationServer.class, + ConfigServer.class, + ContingencyServer.class, + DirectoryNotificationServer.class, + DirectoryServer.class, + DynamicMappingServer.class, + DynamicSimulationServer.class, + ExploreServer.class, + FilterServer.class, + GeoDataServer.class, + LoadFlowServer.class, + MergeNotificationServer.class, + MergeServer.class, + NetworkConversionServer.class, + NetworkModificationServer.class, + OdreServer.class, + ReportServer.class, + SecurityAnalysisServer.class, + SensitivityAnalysisServer.class, + ShortCircuitServer.class, + StudyNotificationServer.class, + StudyServer.class, + UserAdminServer.class, + VoltageInitServer.class + ); + } + + @Test + void testRoutesInitialized() { + StepVerifier.create(myRoutes.getRoutes()) + .as("routes found") + .expectNextCount(28) + .verifyComplete(); + } + + @Test + void testFiltersOrder() { + assertThat(applicationContext.getBeansOfType(GlobalFilter.class) + .values() + .stream() + .sorted(AnnotationAwareOrderComparator.INSTANCE) //sort work only on bean instances + .peek(f -> log.info("p={} ; o={} ; {}", AAOC.INSTANCE.getPriority(f), AAOC.INSTANCE.getOrder(f), f.getClass().getName())) + .map(GlobalFilter::getClass) + .toList()).as("global filters found") + // Before ElementAccessControllerGlobalPreFilter to enforce authentication + .containsSubsequence(TokenValidatorGlobalPreFilter.class, ElementAccessControllerGlobalPreFilter.class) + // Before WebsocketRoutingFilter to control access + .containsSubsequence(ElementAccessControllerGlobalPreFilter.class, WebsocketRoutingFilter.class) + .containsSubsequence( + TokenValidatorGlobalPreFilter.class, //Ordered.LOWEST_PRECEDENCE - 4 + UserAdminControlGlobalPreFilter.class, //Ordered.LOWEST_PRECEDENCE - 3 + ElementAccessControllerGlobalPreFilter.class //Ordered.LOWEST_PRECEDENCE - 2 + ); + } + + private static class AAOC extends AnnotationAwareOrderComparator { + public static final AAOC INSTANCE = new AAOC(); + + @Override + public int getOrder(final Object obj) { + return super.getOrder(obj); + } + } +} diff --git a/src/test/java/org/gridsuite/gateway/TokenValidationTest.java b/src/test/java/org/gridsuite/gateway/TokenValidationTest.java index 693f422..813acff 100644 --- a/src/test/java/org/gridsuite/gateway/TokenValidationTest.java +++ b/src/test/java/org/gridsuite/gateway/TokenValidationTest.java @@ -88,7 +88,7 @@ public class TokenValidationTest { private String expiredToken; - private String tokenWithNotAllowedissuer; + private String tokenWithNotAllowedIssuer; private RSAKey rsaKey; @@ -157,7 +157,7 @@ public void prepareToken() throws JOSEException { token = signedJWT.serialize(); token2 = signedJWT2.serialize(); expiredToken = signedJWTExpired.serialize(); - tokenWithNotAllowedissuer = signedJWTWithIssuerNotAllowed.serialize(); + tokenWithNotAllowedIssuer = signedJWTWithIssuerNotAllowed.serialize(); } private void testWebsocket(String name) throws InterruptedException { @@ -524,7 +524,7 @@ public void invalidToken() { //test with with not allowed issuer webClient .get().uri("case/v1/cases") - .header("Authorization", "Bearer " + tokenWithNotAllowedissuer) + .header("Authorization", "Bearer " + tokenWithNotAllowedIssuer) .exchange() .expectStatus().isEqualTo(401);