|
| 1 | +package org.cloudfoundry.identity.uaa; |
| 2 | + |
| 3 | +import jakarta.servlet.DispatcherType; |
| 4 | +import jakarta.servlet.FilterRegistration; |
| 5 | +import jakarta.servlet.ServletContext; |
| 6 | +import jakarta.servlet.ServletRegistration; |
| 7 | +import org.apache.catalina.core.ApplicationContext; |
| 8 | +import org.apache.catalina.core.ApplicationContextFacade; |
| 9 | +import org.apache.catalina.core.StandardContext; |
| 10 | +import org.apache.tomcat.util.descriptor.web.ErrorPage; |
| 11 | +import org.cloudfoundry.identity.uaa.impl.config.YamlServletProfileInitializer; |
| 12 | +import org.springframework.security.web.session.HttpSessionEventPublisher; |
| 13 | +import org.springframework.web.WebApplicationInitializer; |
| 14 | +import org.springframework.web.context.ContextLoaderListener; |
| 15 | +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; |
| 16 | +import org.springframework.web.filter.DelegatingFilterProxy; |
| 17 | +import org.springframework.web.servlet.DispatcherServlet; |
| 18 | + |
| 19 | +import java.lang.reflect.Field; |
| 20 | +import java.util.EnumSet; |
| 21 | + |
| 22 | +import static org.springframework.util.ReflectionUtils.findField; |
| 23 | +import static org.springframework.util.ReflectionUtils.getField; |
| 24 | + |
| 25 | +public class UaaWebApplicationInitializer implements WebApplicationInitializer { |
| 26 | + @Override |
| 27 | + public void onStartup(ServletContext servletContext) { |
| 28 | + HttpSessionEventPublisher publisher = new HttpSessionEventPublisher(); |
| 29 | + servletContext.addListener(publisher); |
| 30 | + |
| 31 | + AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext(); |
| 32 | + context.register(UaaApplicationConfiguration.class); |
| 33 | + context.setServletContext(servletContext); |
| 34 | + ContextLoaderListener contextLoaderListener = new ContextLoaderListener(context); |
| 35 | + contextLoaderListener.setContextInitializers(new YamlServletProfileInitializer()); |
| 36 | + servletContext.addListener(contextLoaderListener); |
| 37 | + |
| 38 | + //<filter-name>springSessionRepositoryFilter</filter-name> |
| 39 | + DelegatingFilterProxy springSessionRepositoryFilter = new DelegatingFilterProxy("springSessionRepositoryFilter", context); |
| 40 | + FilterRegistration.Dynamic springSessionRepositoryFilterRegistration = servletContext.addFilter( |
| 41 | + "springSessionRepositoryFilter", springSessionRepositoryFilter |
| 42 | + ); |
| 43 | + springSessionRepositoryFilterRegistration.addMappingForUrlPatterns( |
| 44 | + EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR), false, "/*" |
| 45 | + ); |
| 46 | + |
| 47 | + //<filter-name>aggregateSpringSecurityFilterChain</filter-name> |
| 48 | + DelegatingFilterProxy springSecurityFilterChain = new DelegatingFilterProxy("springSecurityFilterChain", context); |
| 49 | + FilterRegistration.Dynamic springSecurityFilterChainRegistration = servletContext.addFilter( |
| 50 | + "springSecurityFilterChain", springSecurityFilterChain |
| 51 | + ); |
| 52 | + springSecurityFilterChainRegistration.setInitParameter( |
| 53 | + "contextAttribute", "org.springframework.web.servlet.FrameworkServlet.CONTEXT.spring" |
| 54 | + ); |
| 55 | + springSecurityFilterChainRegistration.addMappingForUrlPatterns(null, false, "/*"); |
| 56 | + |
| 57 | + //<servlet-name>spring</servlet-name> |
| 58 | + DispatcherServlet spring = new DispatcherServlet(context); |
| 59 | + spring.setDispatchTraceRequest(false); |
| 60 | + ServletRegistration.Dynamic springRegistration = servletContext.addServlet("spring", spring); |
| 61 | + springRegistration.setLoadOnStartup(1); |
| 62 | + springRegistration.addMapping("/"); |
| 63 | + |
| 64 | + //<error-page> from web.xml |
| 65 | + if (servletContext instanceof ApplicationContextFacade) { |
| 66 | + Field field = findField(ApplicationContextFacade.class, "context", ApplicationContext.class); |
| 67 | + field.setAccessible(true); |
| 68 | + ApplicationContext applicationContext = (ApplicationContext) getField(field, servletContext); |
| 69 | + |
| 70 | + field = findField(ApplicationContext.class, "context", StandardContext.class); |
| 71 | + field.setAccessible(true); |
| 72 | + StandardContext standardContext = (StandardContext) getField(field, applicationContext); |
| 73 | + |
| 74 | + ErrorPage error500 = new ErrorPage(); |
| 75 | + error500.setErrorCode(500); |
| 76 | + error500.setLocation("/error500"); |
| 77 | + standardContext.addErrorPage(error500); |
| 78 | + |
| 79 | + ErrorPage error404 = new ErrorPage(); |
| 80 | + error404.setErrorCode(404); |
| 81 | + error404.setLocation("/error404"); |
| 82 | + standardContext.addErrorPage(error404); |
| 83 | + |
| 84 | + ErrorPage error429 = new ErrorPage(); |
| 85 | + error429.setErrorCode(429); |
| 86 | + error429.setLocation("/error429"); |
| 87 | + standardContext.addErrorPage(error429); |
| 88 | + |
| 89 | + ErrorPage error = new ErrorPage(); |
| 90 | + error.setLocation("/error"); |
| 91 | + standardContext.addErrorPage(error); |
| 92 | + } |
| 93 | + } |
| 94 | +} |
0 commit comments