Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
@Component(value = ContingencyServer.ENDPOINT_NAME)
public class ContingencyServer implements EndPointElementServer {
public class ContingencyServer implements EndPointAccessControlledServer {

public static final String ENDPOINT_NAME = "actions";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import static org.springframework.http.HttpMethod.*;

/**
* {@link EndPointServer Server} with access allowed under defined rules
*
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
public interface EndPointElementServer extends EndPointServer {
public interface EndPointAccessControlledServer extends EndPointServer {

String QUERY_PARAM_IDS = "ids";

Set<HttpMethod> ALLOWED_HTTP_METHODS = Set.of(GET, HEAD,
PUT, POST, DELETE
);
Set<HttpMethod> ALLOWED_HTTP_METHODS = Set.of(GET, HEAD, PUT, POST, DELETE);

static UUID getUuid(String uuid) {
try {
Expand All @@ -52,11 +52,6 @@ default boolean isNotControlledRootPath(String rootPath) {
return getUncontrolledRootPaths().contains(rootPath);
}

@Override
default boolean hasElementsAccessControl() {
return true;
}

default Optional<AccessControlInfos> getAccessControlInfos(@NonNull ServerHttpRequest request) {
RequestPath path = Objects.requireNonNull(request.getPath());
UUID elementUuid = getElementUuidIfExist(path);
Expand All @@ -71,7 +66,7 @@ default Optional<AccessControlInfos> getAccessControlInfos(@NonNull ServerHttpRe
return Optional.empty();
} else {
List<String> ids = request.getQueryParams().get(QUERY_PARAM_IDS);
List<UUID> elementUuids = ids.stream().map(EndPointElementServer::getUuid).filter(Objects::nonNull).collect(Collectors.toList());
List<UUID> elementUuids = ids.stream().map(EndPointAccessControlledServer::getUuid).filter(Objects::nonNull).collect(Collectors.toList());
return elementUuids.size() == ids.size() ? Optional.of(AccessControlInfos.create(elementUuids)) : Optional.empty();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import static org.gridsuite.gateway.GatewayConfig.END_POINT_SERVICE_NAME;

/**
* Declare a service/server accessible on this gateway under path {@code <host_gateway>/<service_name>/*}
* and redirect it to {@code <host_service>/*}.
*
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
public interface EndPointServer {
Expand All @@ -27,8 +30,4 @@ default Buildable<Route> getRoute(@NonNull PredicateSpec p) {
.metadata(END_POINT_SERVICE_NAME, getEndpointName())
.uri(getEndpointBaseUri());
}

default boolean hasElementsAccessControl() {
return false;
}
}
13 changes: 4 additions & 9 deletions src/main/java/org/gridsuite/gateway/endpoints/ExploreServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
@Component(value = ExploreServer.ENDPOINT_NAME)
public class ExploreServer implements EndPointElementServer {
public class ExploreServer implements EndPointAccessControlledServer {

public static final String ENDPOINT_NAME = "explore";

Expand All @@ -37,7 +37,7 @@ public ExploreServer(ServiceURIsConfig servicesURIsConfig) {

@Override
public UUID getElementUuidIfExist(@NonNull RequestPath path) {
return (path.elements().size() > 7) ? EndPointElementServer.getUuid(path.elements().get(7).value()) : null;
return (path.elements().size() > 7) ? EndPointAccessControlledServer.getUuid(path.elements().get(7).value()) : null;
}

@Override
Expand All @@ -50,11 +50,6 @@ public String getEndpointName() {
return ENDPOINT_NAME;
}

@Override
public boolean hasElementsAccessControl() {
return true;
}

@Override
public Optional<AccessControlInfos> getAccessControlInfos(@NonNull ServerHttpRequest request) {
RequestPath path = Objects.requireNonNull(request.getPath());
Expand All @@ -69,12 +64,12 @@ public Optional<AccessControlInfos> getAccessControlInfos(@NonNull ServerHttpReq
if (ids == null || ids.size() != 1) {
return Optional.empty();
} else {
UUID uuid = EndPointElementServer.getUuid(ids.get(0));
UUID uuid = EndPointAccessControlledServer.getUuid(ids.get(0));
return uuid == null ? Optional.empty() : Optional.of(AccessControlInfos.create(List.of(uuid)));
}
}
} else {
return EndPointElementServer.super.getAccessControlInfos(request);
return EndPointAccessControlledServer.super.getAccessControlInfos(request);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
@Component(value = FilterServer.ENDPOINT_NAME)
public class FilterServer implements EndPointElementServer {
public class FilterServer implements EndPointAccessControlledServer {

public static final String ENDPOINT_NAME = "filter";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* @author Slimane Amar <slimane.amar at rte-france.com>
*/
@Component(value = StudyServer.ENDPOINT_NAME)
public class StudyServer implements EndPointElementServer {
public class StudyServer implements EndPointAccessControlledServer {

public static final String ENDPOINT_NAME = "study";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import lombok.NonNull;
import org.gridsuite.gateway.ServiceURIsConfig;
import org.gridsuite.gateway.dto.AccessControlInfos;
import org.gridsuite.gateway.endpoints.EndPointElementServer;
import org.gridsuite.gateway.endpoints.EndPointAccessControlledServer;
import org.gridsuite.gateway.endpoints.EndPointServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -36,7 +36,7 @@

import static org.gridsuite.gateway.GatewayConfig.END_POINT_SERVICE_NAME;
import static org.gridsuite.gateway.GatewayConfig.HEADER_USER_ID;
import static org.gridsuite.gateway.endpoints.EndPointElementServer.QUERY_PARAM_IDS;
import static org.gridsuite.gateway.endpoints.EndPointAccessControlledServer.QUERY_PARAM_IDS;
import static org.springframework.http.HttpStatus.*;

/**
Expand All @@ -48,11 +48,10 @@ public class ElementAccessControllerGlobalPreFilter extends AbstractGlobalPreFil
private static final Logger LOGGER = LoggerFactory.getLogger(ElementAccessControllerGlobalPreFilter.class);

private static final String ROOT_CATEGORY_REACTOR = "reactor.";

private static final String ELEMENTS_ROOT_PATH = "elements";
private static final Pattern PATH_API_VERSION = Pattern.compile("^/v(\\d)+/.*");

private final WebClient webClient;

private final ApplicationContext applicationContext;

public ElementAccessControllerGlobalPreFilter(ApplicationContext context, ServiceURIsConfig servicesURIsConfig, WebClient.Builder webClientBuilder) {
Expand All @@ -72,31 +71,34 @@ public Mono<Void> filter(@NonNull ServerWebExchange exchange, @NonNull GatewayFi

RequestPath path = exchange.getRequest().getPath();

// Filter only requests to the endpoint servers with this pattern : /v<number>/<appli_root_path>
if (!Pattern.matches("/v(\\d)+/.*", path.value())) {
// Filter only requests to the endpoint servers with this pattern: /v<number>/<appli_root_path>
if (!PATH_API_VERSION.matcher(path.value()).matches()) {
return chain.filter(exchange);
}

// Is an elements' endpoint with a controlled access ?
String endPointServiceName = Objects.requireNonNull((String) (Objects.requireNonNull((Route) exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR)).getMetadata()).get(END_POINT_SERVICE_NAME));
EndPointServer endPointServer = applicationContext.containsBean(endPointServiceName) ? (EndPointServer) applicationContext.getBean(endPointServiceName) : null;
if (endPointServer == null || !endPointServer.hasElementsAccessControl()) {
// Is an elements' endpoint with controlled access?
final EndPointServer endPointServer = Optional.ofNullable((Route) exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR))
.map(Route::getMetadata)
.map(metadata -> (String) metadata.get(END_POINT_SERVICE_NAME))
.map(endPointServiceName -> applicationContext.containsBean(endPointServiceName) ? (EndPointServer) applicationContext.getBean(endPointServiceName) : null)
.orElse(null);
if (!(endPointServer instanceof EndPointAccessControlledServer accessControlledServer)) {
return chain.filter(exchange);
}

// Is a root path with a controlled access ?
EndPointElementServer endPointElementServer = (EndPointElementServer) endPointServer;
if (endPointElementServer.isNotControlledRootPath(path.elements().get(3).value())) {
// Is a root path with controlled access?
if (accessControlledServer.isNotControlledRootPath(path.elements().get(3).value())) {
return chain.filter(exchange);
}

// Is a method allowed ?
if (!endPointElementServer.isAllowedMethod(exchange.getRequest().getMethod())) {
// Is a method allowed?
if (!accessControlledServer.isAllowedMethod(exchange.getRequest().getMethod())) {
return completeWithCode(exchange, FORBIDDEN);
}

Optional<AccessControlInfos> accessControlInfos = endPointElementServer.getAccessControlInfos(exchange.getRequest());
return accessControlInfos.isEmpty() ? completeWithCode(exchange, FORBIDDEN) : isAccessAllowed(exchange, chain, accessControlInfos.get());
return accessControlledServer.getAccessControlInfos(exchange.getRequest())
.map(controlInfos -> isAccessAllowed(exchange, chain, controlInfos))
.orElseGet(() -> completeWithCode(exchange, FORBIDDEN));
}

private Mono<Void> isAccessAllowed(ServerWebExchange exchange, GatewayFilterChain chain, AccessControlInfos accessControlInfos) {
Expand Down