Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-test</artifactId>
<scope>test</scope>
</dependency>

<!-- compile time only -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
Expand Down
43 changes: 9 additions & 34 deletions src/main/java/org/gridsuite/gateway/GatewayConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
*/
package org.gridsuite.gateway;

import org.gridsuite.gateway.endpoints.CgmesGlServer;
import org.gridsuite.gateway.endpoints.*;
import org.gridsuite.gateway.endpoints.EndPointServer;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource;

import java.util.List;

/**
* @author Chamseddine Benhamed <chamseddine.benhamed at rte-france.com>
* @author Slimane Amar <slimane.amar at rte-france.com>
Expand All @@ -30,36 +30,11 @@ public class GatewayConfig {
public static final String HEADER_CLIENT_ID = "clientId";

@Bean
public RouteLocator myRoutes(RouteLocatorBuilder builder, ApplicationContext context) {
return builder.routes()
.route(p -> context.getBean(StudyServer.class).getRoute(p))
.route(p -> context.getBean(CaseServer.class).getRoute(p))
.route(p -> context.getBean(MergeServer.class).getRoute(p))
.route(p -> context.getBean(StudyNotificationServer.class).getRoute(p))
.route(p -> context.getBean(MergeNotificationServer.class).getRoute(p))
.route(p -> context.getBean(DirectoryNotificationServer.class).getRoute(p))
.route(p -> context.getBean(ContingencyServer.class).getRoute(p))
.route(p -> context.getBean(ConfigServer.class).getRoute(p))
.route(p -> context.getBean(ConfigNotificationServer.class).getRoute(p))
.route(p -> context.getBean(DirectoryServer.class).getRoute(p))
.route(p -> context.getBean(ExploreServer.class).getRoute(p))
.route(p -> context.getBean(CgmesBoundaryServer.class).getRoute(p))
.route(p -> context.getBean(DynamicMappingServer.class).getRoute(p))
.route(p -> context.getBean(FilterServer.class).getRoute(p))
.route(p -> context.getBean(ReportServer.class).getRoute(p))
.route(p -> context.getBean(NetworkModificationServer.class).getRoute(p))
.route(p -> context.getBean(NetworkConversionServer.class).getRoute(p))
.route(p -> context.getBean(OdreServer.class).getRoute(p))
.route(p -> context.getBean(GeoDataServer.class).getRoute(p))
.route(p -> context.getBean(UserAdminServer.class).getRoute(p))
.route(p -> context.getBean(CgmesGlServer.class).getRoute(p))
.route(p -> context.getBean(SensitivityAnalysisServer.class).getRoute(p))
.route(p -> context.getBean(LoadFlowServer.class).getRoute(p))
.route(p -> context.getBean(SecurityAnalysisServer.class).getRoute(p))
.route(p -> context.getBean(DynamicSimulationServer.class).getRoute(p))
.route(p -> context.getBean(CaseImportServer.class).getRoute(p))
.route(p -> context.getBean(VoltageInitServer.class).getRoute(p))
.route(p -> context.getBean(ShortCircuitServer.class).getRoute(p))
.build();
public RouteLocator myRoutes(RouteLocatorBuilder builder, List<EndPointServer> servers) {
final RouteLocatorBuilder.Builder routes = builder.routes();
for (final EndPointServer server : servers) {
routes.route(server.getClass().getName(), server::getRoute);
}
return routes.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ public String getEndpointName() {
return ENDPOINT_NAME;
}

@Override
public boolean hasElementsAccessControl() {
return true;
}

@Override
public Optional<AccessControlInfos> getAccessControlInfos(@NonNull ServerHttpRequest request) {
RequestPath path = Objects.requireNonNull(request.getPath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
/**
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
public abstract class AbstractGlobalPreFilter implements GlobalFilter, Ordered {
abstract class AbstractGlobalPreFilter implements GlobalFilter, Ordered {

protected Mono<Void> completeWithCode(ServerWebExchange exchange, HttpStatus code) {
exchange.getResponse().setStatusCode(code);
if ("websocket".equalsIgnoreCase(exchange.getRequest().getHeaders().getUpgrade())) {
// Force the connection to close for websockets handshakes to workaround apache
// Force the connection to close for websockets handshakes to work around apache
// httpd reusing the connection for all subsequent requests in this connection.
exchange.getResponse().getHeaders().set(HttpHeaders.CONNECTION, "close");
}
Expand Down
106 changes: 106 additions & 0 deletions src/test/java/org/gridsuite/gateway/GatewayApplicationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package org.gridsuite.gateway;

import lombok.extern.slf4j.Slf4j;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.assertj.core.api.WithAssertions;
import org.gridsuite.gateway.endpoints.*;
import org.gridsuite.gateway.filters.ElementAccessControllerGlobalPreFilter;
import org.gridsuite.gateway.filters.TokenValidatorGlobalPreFilter;
import org.gridsuite.gateway.filters.UserAdminControlGlobalPreFilter;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.WebsocketRoutingFilter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import reactor.test.StepVerifier;

import java.util.Map;

@Slf4j
@SpringBootTest
class GatewayApplicationTest implements WithAssertions {
@Autowired
private ApplicationContext applicationContext;

@Autowired
private RouteLocator myRoutes;

@Test
void testAllEndpointServersFound() {
assertThat(applicationContext.getBeansOfType(EndPointServer.class)).as("found EndPointServer beans")
.allSatisfy((name, srv) -> assertThat(name).isEqualTo(srv.getEndpointName()))
.extracting(Map::values, InstanceOfAssertFactories.collection(EndPointServer.class)).as("EndPointServer beans")
.doesNotHaveDuplicates()
.extracting(EndPointServer::getClass).as("EndPointServer classes")
.containsExactlyInAnyOrder(
CaseImportServer.class,
CaseServer.class,
CgmesBoundaryServer.class,
CgmesGlServer.class,
ConfigNotificationServer.class,
ConfigServer.class,
ContingencyServer.class,
DirectoryNotificationServer.class,
DirectoryServer.class,
DynamicMappingServer.class,
DynamicSimulationServer.class,
ExploreServer.class,
FilterServer.class,
GeoDataServer.class,
LoadFlowServer.class,
MergeNotificationServer.class,
MergeServer.class,
NetworkConversionServer.class,
NetworkModificationServer.class,
OdreServer.class,
ReportServer.class,
SecurityAnalysisServer.class,
SensitivityAnalysisServer.class,
ShortCircuitServer.class,
StudyNotificationServer.class,
StudyServer.class,
UserAdminServer.class,
VoltageInitServer.class
);
}

@Test
void testRoutesInitialized() {
StepVerifier.create(myRoutes.getRoutes())
.as("routes found")
.expectNextCount(28)
.verifyComplete();
}

@Test
void testFiltersOrder() {
assertThat(applicationContext.getBeansOfType(GlobalFilter.class)
.values()
.stream()
.sorted(AnnotationAwareOrderComparator.INSTANCE) //sort work only on bean instances
.peek(f -> log.info("p={} ; o={} ; {}", AAOC.INSTANCE.getPriority(f), AAOC.INSTANCE.getOrder(f), f.getClass().getName()))
.map(GlobalFilter::getClass)
.toList()).as("global filters found")
// Before ElementAccessControllerGlobalPreFilter to enforce authentication
.containsSubsequence(TokenValidatorGlobalPreFilter.class, ElementAccessControllerGlobalPreFilter.class)
// Before WebsocketRoutingFilter to control access
.containsSubsequence(ElementAccessControllerGlobalPreFilter.class, WebsocketRoutingFilter.class)
.containsSubsequence(
TokenValidatorGlobalPreFilter.class, //Ordered.LOWEST_PRECEDENCE - 4
UserAdminControlGlobalPreFilter.class, //Ordered.LOWEST_PRECEDENCE - 3
ElementAccessControllerGlobalPreFilter.class //Ordered.LOWEST_PRECEDENCE - 2
);
}

private static class AAOC extends AnnotationAwareOrderComparator {
public static final AAOC INSTANCE = new AAOC();

@Override
public int getOrder(final Object obj) {
return super.getOrder(obj);
}
}
}
6 changes: 3 additions & 3 deletions src/test/java/org/gridsuite/gateway/TokenValidationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public class TokenValidationTest {

private String expiredToken;

private String tokenWithNotAllowedissuer;
private String tokenWithNotAllowedIssuer;

private RSAKey rsaKey;

Expand Down Expand Up @@ -157,7 +157,7 @@ public void prepareToken() throws JOSEException {
token = signedJWT.serialize();
token2 = signedJWT2.serialize();
expiredToken = signedJWTExpired.serialize();
tokenWithNotAllowedissuer = signedJWTWithIssuerNotAllowed.serialize();
tokenWithNotAllowedIssuer = signedJWTWithIssuerNotAllowed.serialize();
}

private void testWebsocket(String name) throws InterruptedException {
Expand Down Expand Up @@ -524,7 +524,7 @@ public void invalidToken() {
//test with with not allowed issuer
webClient
.get().uri("case/v1/cases")
.header("Authorization", "Bearer " + tokenWithNotAllowedissuer)
.header("Authorization", "Bearer " + tokenWithNotAllowedIssuer)
.exchange()
.expectStatus().isEqualTo(401);

Expand Down