1
+ import java .lang .reflect .Method ;
2
+ import java .util .HashMap ;
1
3
import java .util .HashSet ;
4
+ import java .util .List ;
5
+ import java .util .Map ;
2
6
import javax .servlet .http .HttpServletRequest ;
3
7
import org .springframework .stereotype .Controller ;
8
+ import org .springframework .util .StringUtils ;
4
9
import org .springframework .web .bind .annotation .GetMapping ;
5
- import org .springframework .web .bind .annotation .ResponseBody ;
10
+ import org .springframework .web .bind .annotation .PathVariable ;
11
+ import org .springframework .web .bind .annotation .RequestBody ;
12
+ import org .springframework .web .bind .annotation .RequestMapping ;
13
+ import org .springframework .web .bind .annotation .RequestMethod ;
14
+ import org .springframework .web .multipart .MultipartFile ;
6
15
7
16
@ Controller
8
17
public class UnsafeReflection {
9
18
10
19
@ GetMapping (value = "uf1" )
11
20
public void bad1 (HttpServletRequest request ) {
12
21
String className = request .getParameter ("className" );
22
+ String parameterValue = request .getParameter ("parameterValue" );
13
23
try {
14
- Class clazz = Class .forName (className ); //bad
15
- } catch (ClassNotFoundException e ) {
24
+ Class clazz = Class .forName (className );
25
+ Object object = clazz .getDeclaredConstructors ()[0 ].newInstance (parameterValue ); //bad
26
+ } catch (Exception e ) {
16
27
e .printStackTrace ();
17
28
}
18
29
}
19
30
20
31
@ GetMapping (value = "uf2" )
21
32
public void bad2 (HttpServletRequest request ) {
22
33
String className = request .getParameter ("className" );
34
+ String parameterValue = request .getParameter ("parameterValue" );
23
35
try {
24
36
ClassLoader classLoader = ClassLoader .getSystemClassLoader ();
25
- Class clazz = classLoader .loadClass (className ); //bad
26
- } catch (ClassNotFoundException e ) {
37
+ Class clazz = classLoader .loadClass (className );
38
+ Object object = clazz .newInstance ();
39
+ clazz .getDeclaredMethods ()[0 ].invoke (object , parameterValue ); //bad
40
+ } catch (Exception e ) {
27
41
e .printStackTrace ();
28
42
}
29
43
}
30
44
45
+ @ RequestMapping (value = {"/service/{beanIdOrClassName}/{methodName}" }, method = {RequestMethod .POST }, consumes = {"application/json" }, produces = {"application/json" })
46
+ public Object bad3 (@ PathVariable ("beanIdOrClassName" ) String beanIdOrClassName , @ PathVariable ("methodName" ) String methodName , @ RequestBody Map <String , Object > body ) throws Exception {
47
+ List <Object > rawData = null ;
48
+ try {
49
+ rawData = (List <Object >)body .get ("methodInput" );
50
+ } catch (Exception e ) {
51
+ return e ;
52
+ }
53
+ return invokeService (beanIdOrClassName , methodName , null , rawData );
54
+ }
55
+
31
56
@ GetMapping (value = "uf3" )
32
57
public void good1 (HttpServletRequest request ) throws Exception {
33
58
HashSet <String > hashSet = new HashSet <>();
34
59
hashSet .add ("com.example.test1" );
35
60
hashSet .add ("com.example.test2" );
36
61
String className = request .getParameter ("className" );
37
- if (hashSet .contains (className )){ //good
62
+ String parameterValue = request .getParameter ("parameterValue" );
63
+ if (!hashSet .contains (className )){
38
64
throw new Exception ("Class not valid: " + className );
39
65
}
40
- ClassLoader classLoader = ClassLoader .getSystemClassLoader ();
41
- Class clazz = classLoader .loadClass (className );
66
+ try {
67
+ Class clazz = Class .forName (className );
68
+ Object object = clazz .getDeclaredConstructors ()[0 ].newInstance (parameterValue ); //good
69
+ } catch (Exception e ) {
70
+ e .printStackTrace ();
71
+ }
42
72
}
43
73
44
74
@ GetMapping (value = "uf4" )
45
75
public void good2 (HttpServletRequest request ) throws Exception {
46
76
String className = request .getParameter ("className" );
47
- if (!"com.example.test1" .equals (className )){ //good
77
+ String parameterValue = request .getParameter ("parameterValue" );
78
+ if (!"com.example.test1" .equals (className )){
48
79
throw new Exception ("Class not valid: " + className );
49
80
}
50
- ClassLoader classLoader = ClassLoader .getSystemClassLoader ();
51
- Class clazz = classLoader .loadClass (className );
81
+ try {
82
+ Class clazz = Class .forName (className );
83
+ Object object = clazz .getDeclaredConstructors ()[0 ].newInstance (parameterValue ); //good
84
+ } catch (Exception e ) {
85
+ e .printStackTrace ();
86
+ }
52
87
}
53
88
54
89
@ GetMapping (value = "uf5" )
55
90
public void good3 (HttpServletRequest request ) throws Exception {
56
91
String className = request .getParameter ("className" );
92
+ String parameterValue = request .getParameter ("parameterValue" );
57
93
if (!className .equals ("com.example.test1" )){ //good
58
94
throw new Exception ("Class not valid: " + className );
59
95
}
60
- ClassLoader classLoader = ClassLoader .getSystemClassLoader ();
61
- Class clazz = classLoader .loadClass (className );
96
+ try {
97
+ Class clazz = Class .forName (className );
98
+ Object object = clazz .getDeclaredConstructors ()[0 ].newInstance (parameterValue ); //good
99
+ } catch (Exception e ) {
100
+ e .printStackTrace ();
101
+ }
102
+ }
103
+
104
+ private Object invokeService (String beanIdOrClassName , String methodName , MultipartFile [] files , List <Object > data ) throws Exception {
105
+ BeanFactory beanFactory = new BeanFactory ();
106
+ try {
107
+ Object bean = null ;
108
+ String beanName = null ;
109
+ Class <?> beanClass = null ;
110
+ try {
111
+ beanClass = Class .forName (beanIdOrClassName );
112
+ beanName = StringUtils .uncapitalize (beanClass .getSimpleName ());
113
+ } catch (ClassNotFoundException classNotFoundException ) {
114
+ beanName = beanIdOrClassName ;
115
+ }
116
+ try {
117
+ bean = beanFactory .getBean (beanName );
118
+ } catch (BeansException beansException ) {
119
+ bean = beanFactory .getBean (beanClass );
120
+ }
121
+ byte b ;
122
+ int i ;
123
+ Method [] arrayOfMethod ;
124
+ for (i = (arrayOfMethod = bean .getClass ().getMethods ()).length , b = 0 ; b < i ; ) {
125
+ Method method = arrayOfMethod [b ];
126
+ if (!method .getName ().equals (methodName )) {
127
+ b ++;
128
+ continue ;
129
+ }
130
+ ProxygenSerializer serializer = new ProxygenSerializer ();
131
+ Object [] methodInput = serializer .deserializeMethodInput (data , files , method );
132
+ Object result = method .invoke (bean , methodInput );
133
+ Map <String , Object > map = new HashMap <>();
134
+ map .put ("result" , serializer .serialize (result ));
135
+ return map ;
136
+ }
137
+ } catch (Exception e ) {
138
+ return e ;
139
+ }
140
+ return null ;
141
+ }
142
+ }
143
+
144
+ class BeansException extends Exception {
145
+
146
+ }
147
+
148
+ class BeanFactory {
149
+
150
+ private static HashMap <String , Object > classNameMap = new HashMap <>();
151
+
152
+ private static HashMap <Class <?>, Object > classMap = new HashMap <>();;
153
+
154
+ static {
155
+ classNameMap .put ("xxxx" , Runtime .getRuntime ());
156
+ classMap .put (Runtime .class , Runtime .getRuntime ());
157
+ }
158
+
159
+ public Object getBean (String className ) throws BeansException {
160
+ if (classNameMap .get (className ) == null ) {
161
+ throw new BeansException ();
162
+ }
163
+ return classNameMap .get (className );
164
+ }
165
+
166
+ public Object getBean (Class <?> clzz ) {
167
+ return classMap .get (clzz );
168
+ }
169
+ }
170
+
171
+ class ProxygenSerializer {
172
+
173
+ public Object [] deserializeMethodInput (List <Object > data , MultipartFile [] files , Method method ) {
174
+ return null ;
175
+ }
176
+
177
+ public String serialize (Object result ) {
178
+ return null ;
62
179
}
63
- }
180
+ }
0 commit comments