1+ package com .baeldung .graphql .fileupload ;
2+
3+ import org .apache .commons .logging .Log ;
4+ import org .apache .commons .logging .LogFactory ;
5+ import org .springframework .http .HttpHeaders ;
6+ import org .springframework .http .HttpInputMessage ;
7+ import org .springframework .http .converter .GenericHttpMessageConverter ;
8+ import org .springframework .util .StringUtils ;
9+ import org .springframework .web .multipart .MultipartFile ;
10+ import org .springframework .web .multipart .MultipartHttpServletRequest ;
11+ import reactor .core .publisher .Mono ;
12+
13+ import org .springframework .context .i18n .LocaleContextHolder ;
14+ import org .springframework .core .ParameterizedTypeReference ;
15+ import org .springframework .graphql .server .WebGraphQlHandler ;
16+ import org .springframework .graphql .server .WebGraphQlRequest ;
17+ import org .springframework .http .MediaType ;
18+ import org .springframework .util .AlternativeJdkIdGenerator ;
19+ import org .springframework .util .Assert ;
20+ import org .springframework .util .IdGenerator ;
21+ import org .springframework .web .servlet .function .ServerRequest ;
22+ import org .springframework .web .servlet .function .ServerResponse ;
23+
24+ import javax .servlet .ServletException ;
25+ import javax .servlet .http .HttpServletRequest ;
26+ import javax .servlet .http .Part ;
27+ import java .io .IOException ;
28+ import java .io .InputStream ;
29+ import java .lang .reflect .Type ;
30+ import java .util .*;
31+
32+ import static org .springframework .http .MediaType .APPLICATION_GRAPHQL ;
33+
34+ public class MultipartGraphQlHttpHandler {
35+
36+ private static final Log logger = LogFactory .getLog (MultipartGraphQlHttpHandler .class );
37+
38+ private static final ParameterizedTypeReference <Map <String , Object >> MAP_PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference <Map <String , Object >>() {
39+ };
40+
41+ private static final ParameterizedTypeReference <Map <String , List <String >>> LIST_PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference <Map <String , List <String >>>() {
42+ };
43+
44+ public static final List <MediaType > SUPPORTED_MEDIA_TYPES = Arrays .asList (APPLICATION_GRAPHQL , MediaType .APPLICATION_JSON , MediaType .APPLICATION_GRAPHQL );
45+
46+ private final IdGenerator idGenerator = new AlternativeJdkIdGenerator ();
47+
48+ private final WebGraphQlHandler graphQlHandler ;
49+
50+ private final GenericHttpMessageConverter genericHttpMessageConverter ;
51+
52+ public MultipartGraphQlHttpHandler (WebGraphQlHandler graphQlHandler , GenericHttpMessageConverter genericHttpMessageConverter ) {
53+ Assert .notNull (graphQlHandler , "WebGraphQlHandler is required" );
54+ Assert .notNull (genericHttpMessageConverter , "GenericHttpMessageConverter is required" );
55+ this .graphQlHandler = graphQlHandler ;
56+ this .genericHttpMessageConverter = genericHttpMessageConverter ;
57+ }
58+
59+ public ServerResponse handleMultipartRequest (ServerRequest serverRequest ) throws ServletException {
60+ HttpServletRequest httpServletRequest = serverRequest .servletRequest ();
61+
62+ Map <String , Object > inputQuery = Optional .ofNullable (this .<Map <String , Object >>deserializePart (httpServletRequest , "operations" , MAP_PARAMETERIZED_TYPE_REF .getType ())).orElse (new HashMap <>());
63+
64+ final Map <String , Object > queryVariables = getFromMapOrEmpty (inputQuery , "variables" );
65+ final Map <String , Object > extensions = getFromMapOrEmpty (inputQuery , "extensions" );
66+
67+ Map <String , MultipartFile > fileParams = readMultipartFiles (httpServletRequest );
68+
69+ Map <String , List <String >> fileMappings = Optional .ofNullable (this .<Map <String , List <String >>>deserializePart (httpServletRequest , "map" , LIST_PARAMETERIZED_TYPE_REF .getType ())).orElse (new HashMap <>());
70+
71+ fileMappings .forEach ((String fileKey , List <String > objectPaths ) -> {
72+ MultipartFile file = fileParams .get (fileKey );
73+ if (file != null ) {
74+ objectPaths .forEach ((String objectPath ) -> {
75+ MultipartVariableMapper .mapVariable (objectPath , queryVariables , file );
76+ });
77+ }
78+ });
79+
80+ String query = (String ) inputQuery .get ("query" );
81+ String opName = (String ) inputQuery .get ("operationName" );
82+
83+ Map <String , Object > body = new HashMap <>();
84+ body .put ("query" , query );
85+ body .put ("operationName" , StringUtils .hasText (opName ) ? opName : "" );
86+ body .put ("variables" , queryVariables );
87+ body .put ("extensions" , extensions );
88+
89+ WebGraphQlRequest graphQlRequest = new WebGraphQlRequest (serverRequest .uri (), serverRequest .headers ().asHttpHeaders (), body , this .idGenerator .generateId ().toString (), LocaleContextHolder .getLocale ());
90+
91+ if (logger .isDebugEnabled ()) {
92+ logger .debug ("Executing: " + graphQlRequest );
93+ }
94+
95+ Mono <ServerResponse > responseMono = this .graphQlHandler .handleRequest (graphQlRequest ).map (response -> {
96+ if (logger .isDebugEnabled ()) {
97+ logger .debug ("Execution complete" );
98+ }
99+ ServerResponse .BodyBuilder builder = ServerResponse .ok ();
100+ builder .headers (headers -> headers .putAll (response .getResponseHeaders ()));
101+ builder .contentType (selectResponseMediaType (serverRequest ));
102+ return builder .body (response .toMap ());
103+ });
104+
105+ return ServerResponse .async (responseMono );
106+ }
107+
108+ private static class JsonMultipartInputMessage implements HttpInputMessage {
109+
110+ private final Part part ;
111+
112+ JsonMultipartInputMessage (Part part ) {
113+ this .part = part ;
114+ }
115+
116+ @ Override
117+ public InputStream getBody () throws IOException {
118+ return this .part .getInputStream ();
119+ }
120+
121+ @ Override
122+ public HttpHeaders getHeaders () {
123+ HttpHeaders httpHeaders = new HttpHeaders ();
124+ httpHeaders .setContentType (MediaType .APPLICATION_JSON );
125+ return httpHeaders ;
126+ }
127+ }
128+
129+ private <T > T deserializePart (HttpServletRequest httpServletRequest , String name , Type type ) {
130+ try {
131+ Part part = httpServletRequest .getPart (name );
132+ if (part == null ) {
133+ return null ;
134+ }
135+ return (T ) this .genericHttpMessageConverter .read (type , null , new JsonMultipartInputMessage (part ));
136+ } catch (IOException | ServletException e ) {
137+ throw new RuntimeException (e );
138+ }
139+ }
140+
141+ @ SuppressWarnings ("unchecked" )
142+ private Map <String , Object > getFromMapOrEmpty (Map <String , Object > input , String key ) {
143+ if (input .containsKey (key )) {
144+ return (Map <String , Object >) input .get (key );
145+ } else {
146+ return new HashMap <>();
147+ }
148+ }
149+
150+ private static Map <String , MultipartFile > readMultipartFiles (HttpServletRequest httpServletRequest ) {
151+ Assert .isInstanceOf (MultipartHttpServletRequest .class , httpServletRequest , "Request should be of type MultipartHttpServletRequest" );
152+ MultipartHttpServletRequest multipartHttpServletRequest = (MultipartHttpServletRequest ) httpServletRequest ;
153+ return multipartHttpServletRequest .getFileMap ();
154+ }
155+
156+ private static MediaType selectResponseMediaType (ServerRequest serverRequest ) {
157+ for (MediaType accepted : serverRequest .headers ().accept ()) {
158+ if (SUPPORTED_MEDIA_TYPES .contains (accepted )) {
159+ return accepted ;
160+ }
161+ }
162+ return MediaType .APPLICATION_JSON ;
163+ }
164+
165+ }
0 commit comments