diff --git a/openam-rest/src/main/java/org/forgerock/openam/cors/CORSService.java b/openam-rest/src/main/java/org/forgerock/openam/cors/CORSService.java index d1893a79da..f8bc9f6955 100644 --- a/openam-rest/src/main/java/org/forgerock/openam/cors/CORSService.java +++ b/openam-rest/src/main/java/org/forgerock/openam/cors/CORSService.java @@ -12,16 +12,20 @@ * information: "Portions copyright [year] [name of copyright owner]". * * Copyright 2014 ForgeRock AS. +* Portions copyright 2024-2025 3A Systems LLC. */ package org.forgerock.openam.cors; import com.sun.identity.shared.debug.Debug; import org.apache.commons.collections4.CollectionUtils; +import org.forgerock.json.JsonValue; +import org.forgerock.json.resource.ResourceException; import org.forgerock.openam.cors.utils.CSVHelper; import org.forgerock.util.Reject; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -126,8 +130,10 @@ public CORSService(final boolean enabled, List acceptedOrigins, List origins = new ArrayList<>(); origins.add("www.google.com"); ArrayList methods = new ArrayList<>(); @@ -54,6 +58,8 @@ public void setUp() { mockRequest = mock(HttpServletRequest.class); mockResponse = mock(HttpServletResponse.class); + + when(mockResponse.getWriter()).thenReturn(mock(PrintWriter.class)); } @Test(expectedExceptions = IllegalArgumentException.class) @@ -108,7 +114,7 @@ public void EmptyMethodsThrowIllegalArgument() { } @Test - public void shouldNotTouchResponseAsOriginNull() { + public void shouldNotTouchResponseAsOriginNull() throws IOException { //given given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn(null); @@ -120,7 +126,7 @@ public void shouldNotTouchResponseAsOriginNull() { } @Test - public void shouldNotTouchResponseAsOriginEmpty() { + public void shouldReturnBadRequestAsOriginEmpty() throws IOException { //given given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn(""); @@ -128,11 +134,11 @@ public void shouldNotTouchResponseAsOriginEmpty() { testService.handleRequest(mockRequest, mockResponse); //then - verifyZeroInteractions(mockResponse); + verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST)); } @Test - public void shouldNotTouchResponseAsOriginInvalid() { + public void shouldReturnBadRequestAsOriginInvalid() throws IOException { //given given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.yahoo.com"); @@ -140,11 +146,11 @@ public void shouldNotTouchResponseAsOriginInvalid() { testService.handleRequest(mockRequest, mockResponse); //then - verifyZeroInteractions(mockResponse); + verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST)); } @Test - public void shouldNotTouchResponseAsOriginCaseInvalid() { + public void shouldReturnBadRequestAsOriginCaseInvalid() throws IOException { //given given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.GOOGLE.com"); @@ -152,12 +158,12 @@ public void shouldNotTouchResponseAsOriginCaseInvalid() { testService.handleRequest(mockRequest, mockResponse); //then - verifyZeroInteractions(mockResponse); + verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST)); } @Test - public void shouldNotTouchResponseAsMethodInvalid() { + public void shouldReturnBadRequestAsMethodInvalid() throws IOException { //given given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.google.com"); given(mockRequest.getMethod()).willReturn("PUT"); @@ -166,11 +172,11 @@ public void shouldNotTouchResponseAsMethodInvalid() { testService.handleRequest(mockRequest, mockResponse); //then - verifyZeroInteractions(mockResponse); + verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST)); } @Test - public void shouldFollowNormalFlowApplyOriginCredsAndExpose() { + public void shouldFollowNormalFlowApplyOriginCredsAndExpose() throws IOException { //given given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.google.com"); given(mockRequest.getMethod()).willReturn("POST"); @@ -186,7 +192,7 @@ public void shouldFollowNormalFlowApplyOriginCredsAndExpose() { } @Test - public void shouldFollowNormalFlowJustApplyOrigin() { + public void shouldFollowNormalFlowJustApplyOrigin() throws IOException { //given ArrayList origins = new ArrayList(); origins.add("*"); @@ -206,7 +212,7 @@ public void shouldFollowNormalFlowJustApplyOrigin() { } @Test - public void shouldFollowPreflightFlow() { + public void shouldFollowPreflightFlow() throws IOException { given(mockRequest.getHeader(CORSConstants.ORIGIN)).willReturn("www.google.com"); given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn("POST"); given(mockRequest.getMethod()).willReturn("OPTIONS"); @@ -224,7 +230,7 @@ public void shouldFollowPreflightFlow() { } @Test - public void shouldDoNothingIfPreflightAndNotOptions() { + public void shouldDoNothingIfPreflightAndNotOptions() throws IOException { given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn("POST"); given(mockRequest.getMethod()).willReturn("GET"); @@ -236,7 +242,7 @@ public void shouldDoNothingIfPreflightAndNotOptions() { } @Test - public void shouldDoNothingIfPreflightAndNullRequestMethod() { + public void shouldDoNothingIfPreflightAndNullRequestMethod() throws IOException { given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn(null); given(mockRequest.getMethod()).willReturn("GET"); @@ -248,7 +254,7 @@ public void shouldDoNothingIfPreflightAndNullRequestMethod() { } @Test - public void shouldDoNothingIfPreflightAndEmptyRequestMethod() { + public void shouldDoNothingIfPreflightAndEmptyRequestMethod() throws IOException { given(mockRequest.getHeader(CORSConstants.AC_REQUEST_METHOD)).willReturn(""); given(mockRequest.getMethod()).willReturn("GET"); @@ -260,7 +266,7 @@ public void shouldDoNothingIfPreflightAndEmptyRequestMethod() { } @Test - public void testInvalidHostnameFailsValidation() { + public void testInvalidHostnameFailsValidation() throws IOException { ArrayList origins = new ArrayList(); origins.add("www.google.com"); @@ -282,12 +288,12 @@ public void testInvalidHostnameFailsValidation() { testService.handleRequest(mockRequest, mockResponse); //then - verifyZeroInteractions(mockResponse); + verify(mockResponse, times(1)).setStatus(eq(HttpServletResponse.SC_BAD_REQUEST)); } @Test - public void testHandleNormalIncludesExposedHeadersInResponse() { + public void testHandleNormalIncludesExposedHeadersInResponse() throws IOException { ArrayList origins = new ArrayList(); origins.add("www.google.com");