Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,20 @@
* information: "Portions copyright [year] [name of copyright owner]".
*
* Copyright 2014 ForgeRock AS.
* Portions copyright 2024-2025 3A Systems LLC.
*/
package org.forgerock.openam.cors;

import com.sun.identity.shared.debug.Debug;
import org.apache.commons.collections4.CollectionUtils;
import org.forgerock.json.JsonValue;
import org.forgerock.json.resource.ResourceException;
import org.forgerock.openam.cors.utils.CSVHelper;
import org.forgerock.util.Reject;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -126,8 +130,10 @@ public CORSService(final boolean enabled, List<String> acceptedOrigins, List<Str
* @param req CORS HTTP request
* @param res HTTP response
* @return true if the caller is to continue processing the request
*
* @throws IOException if an I/O error occurs while handling the response
*/
public boolean handleRequest(final HttpServletRequest req, final HttpServletResponse res) {
public boolean handleRequest(final HttpServletRequest req, final HttpServletResponse res) throws IOException {
if(!this.enabled) {
return true;
}
Expand All @@ -136,6 +142,7 @@ public boolean handleRequest(final HttpServletRequest req, final HttpServletResp
}

if (!isValidCORSRequest(req)) {
handleFailedCORS(res);
return false;
}

Expand All @@ -148,6 +155,22 @@ public boolean handleRequest(final HttpServletRequest req, final HttpServletResp

}

/**
* Handles a failed CORS (Cross-Origin Resource Sharing) request by generating a standardized
* JSON error response and setting the appropriate HTTP status code.
*
* @param res the {@link HttpServletResponse} to which the error response will be written
* @throws IOException if an I/O error occurs while writing the response
*/
private void handleFailedCORS(HttpServletResponse res) throws IOException {
ResourceException resourceException = ResourceException.getException(HttpServletResponse.SC_BAD_REQUEST, "CORS error occurred");
JsonValue jsonValue = resourceException.toJsonValue();
res.setStatus(resourceException.getCode());
res.setContentType("application/json");
res.setCharacterEncoding("UTF-8");
res.getWriter().write(jsonValue.toString());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just keep in mind that JsonValue#toString doesn't always return actual proper JSON. If you call it on something like a Map value (result of json.get(field)), it will just print out the "value", not a valid JSON.

}

/**
* Handles the preflight flow.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
* information: "Portions copyright [year] [name of copyright owner]".
*
* Copyright 2014 ForgeRock AS.
* Portions Copyrighted 2024 3A Systems LLC.
* Portions Copyrighted 2024-2025 3A Systems LLC.
*/
package org.forgerock.openam.cors;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand All @@ -27,6 +29,8 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;

import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

Expand All @@ -37,7 +41,7 @@ public class CORSServiceTest {
private HttpServletResponse mockResponse;

@BeforeMethod
public void setUp() {
public void setUp() throws IOException {
ArrayList<String> origins = new ArrayList<>();
origins.add("www.google.com");
ArrayList<String> methods = new ArrayList<>();
Expand All @@ -54,6 +58,8 @@ public void setUp() {

mockRequest = mock(HttpServletRequest.class);
mockResponse = mock(HttpServletResponse.class);

when(mockResponse.getWriter()).thenReturn(mock(PrintWriter.class));
}

@Test(expectedExceptions = IllegalArgumentException.class)
Expand Down Expand Up @@ -108,7 +114,7 @@ public void EmptyMethodsThrowIllegalArgument() {
}

@Test
public void shouldNotTouchResponseAsOriginNull() {
public void shouldNotTouchResponseAsOriginNull() throws IOException {
//given
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn(null);

Expand All @@ -120,44 +126,44 @@ public void shouldNotTouchResponseAsOriginNull() {
}

@Test
public void shouldNotTouchResponseAsOriginEmpty() {
public void shouldReturnBadRequestAsOriginEmpty() throws IOException {
//given
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("");

//when
testService.handleRequest(mockRequest, mockResponse);

//then
verifyZeroInteractions(mockResponse);
verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST));
}

@Test
public void shouldNotTouchResponseAsOriginInvalid() {
public void shouldReturnBadRequestAsOriginInvalid() throws IOException {
//given
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.yahoo.com");

//when
testService.handleRequest(mockRequest, mockResponse);

//then
verifyZeroInteractions(mockResponse);
verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST));
}

@Test
public void shouldNotTouchResponseAsOriginCaseInvalid() {
public void shouldReturnBadRequestAsOriginCaseInvalid() throws IOException {
//given
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.GOOGLE.com");

//when
testService.handleRequest(mockRequest, mockResponse);

//then
verifyZeroInteractions(mockResponse);
verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST));
}


@Test
public void shouldNotTouchResponseAsMethodInvalid() {
public void shouldReturnBadRequestAsMethodInvalid() throws IOException {
//given
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.google.com");
given(mockRequest.getMethod()).willReturn("PUT");
Expand All @@ -166,11 +172,11 @@ public void shouldNotTouchResponseAsMethodInvalid() {
testService.handleRequest(mockRequest, mockResponse);

//then
verifyZeroInteractions(mockResponse);
verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST));
}

@Test
public void shouldFollowNormalFlowApplyOriginCredsAndExpose() {
public void shouldFollowNormalFlowApplyOriginCredsAndExpose() throws IOException {
//given
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.google.com");
given(mockRequest.getMethod()).willReturn("POST");
Expand All @@ -186,7 +192,7 @@ public void shouldFollowNormalFlowApplyOriginCredsAndExpose() {
}

@Test
public void shouldFollowNormalFlowJustApplyOrigin() {
public void shouldFollowNormalFlowJustApplyOrigin() throws IOException {
//given
ArrayList<String> origins = new ArrayList<String>();
origins.add("*");
Expand All @@ -206,7 +212,7 @@ public void shouldFollowNormalFlowJustApplyOrigin() {
}

@Test
public void shouldFollowPreflightFlow() {
public void shouldFollowPreflightFlow() throws IOException {
given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.google.com");
given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn("POST");
given(mockRequest.getMethod()).willReturn("OPTIONS");
Expand All @@ -224,7 +230,7 @@ public void shouldFollowPreflightFlow() {
}

@Test
public void shouldDoNothingIfPreflightAndNotOptions() {
public void shouldDoNothingIfPreflightAndNotOptions() throws IOException {
given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn("POST");
given(mockRequest.getMethod()).willReturn("GET");

Expand All @@ -236,7 +242,7 @@ public void shouldDoNothingIfPreflightAndNotOptions() {
}

@Test
public void shouldDoNothingIfPreflightAndNullRequestMethod() {
public void shouldDoNothingIfPreflightAndNullRequestMethod() throws IOException {
given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn(null);
given(mockRequest.getMethod()).willReturn("GET");

Expand All @@ -248,7 +254,7 @@ public void shouldDoNothingIfPreflightAndNullRequestMethod() {
}

@Test
public void shouldDoNothingIfPreflightAndEmptyRequestMethod() {
public void shouldDoNothingIfPreflightAndEmptyRequestMethod() throws IOException {
given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn("");
given(mockRequest.getMethod()).willReturn("GET");

Expand All @@ -260,7 +266,7 @@ public void shouldDoNothingIfPreflightAndEmptyRequestMethod() {
}

@Test
public void testInvalidHostnameFailsValidation() {
public void testInvalidHostnameFailsValidation() throws IOException {

ArrayList<String> origins = new ArrayList<String>();
origins.add("www.google.com");
Expand All @@ -282,12 +288,12 @@ public void testInvalidHostnameFailsValidation() {
testService.handleRequest(mockRequest, mockResponse);

//then
verifyZeroInteractions(mockResponse);
verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST));
}


@Test
public void testHandleNormalIncludesExposedHeadersInResponse() {
public void testHandleNormalIncludesExposedHeadersInResponse() throws IOException {

ArrayList<String> origins = new ArrayList<String>();
origins.add("www.google.com");
Expand Down