Skip to content

Commit fa88990

Browse files
author
‘niuerzhuang’
committed
fix: custom model
1 parent 625f542 commit fa88990

File tree

6 files changed

+85
-128
lines changed

6 files changed

+85
-128
lines changed

dongtai-core/src/main/java/io/dongtai/iast/core/handler/hookpoint/controller/impl/DubboImpl.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,14 @@ public static void solveDubboRequest(Object handler, Object channel, Object requ
4343
}
4444

4545

46-
4746
public static void collectDubboRequestSource(Object handler, Object invocation, String methodName,
4847
Object[] arguments, Map<String, ?> headers,
4948
String hookClass, String hookMethod, String hookSign,
5049
AtomicInteger invokeIdSequencer) {
5150
if (arguments == null || arguments.length == 0) {
5251
return;
5352
}
54-
Map <String, Object> requestMeta = EngineManager.REQUEST_CONTEXT.get();
53+
Map<String, Object> requestMeta = EngineManager.REQUEST_CONTEXT.get();
5554
if (requestMeta == null) {
5655
return;
5756
}
@@ -70,7 +69,7 @@ public static void collectDubboRequestSource(Object handler, Object invocation,
7069
tgt.add(new TaintPosition("P1"));
7170

7271
SourceNode sourceNode = new SourceNode(src, tgt, null);
73-
TaintPoolUtils.trackObject(event, sourceNode, arguments, 0);
72+
TaintPoolUtils.trackObject(event, sourceNode, arguments, 0, true);
7473

7574
Map<String, String> sHeaders = new HashMap<String, String>();
7675
if (headers != null) {

dongtai-core/src/main/java/io/dongtai/iast/core/handler/hookpoint/controller/impl/SourceImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ private static boolean trackTarget(MethodEvent event, SourceNode sourceNode) {
8181
return false;
8282
}
8383

84-
TaintPoolUtils.trackObject(event, sourceNode, event.returnInstance, 0);
84+
TaintPoolUtils.trackObject(event, sourceNode, event.returnInstance, 0, false);
8585
return true;
8686
}
8787

dongtai-core/src/main/java/io/dongtai/iast/core/handler/hookpoint/service/trace/DubboService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public class DubboService {
1515
public static void solveSyncInvoke(MethodEvent event, Object invocation, String url, Map<String, String> headers,
1616
AtomicInteger invokeIdSequencer) {
1717
try {
18-
TaintPoolUtils.trackObject(event, null, event.parameterInstances, 0);
18+
TaintPoolUtils.trackObject(event, null, event.parameterInstances, 0, false);
1919
boolean hasTaint = false;
2020
int sourceLen = 0;
2121
if (!event.getSourceHashes().isEmpty()) {
@@ -26,7 +26,7 @@ public static void solveSyncInvoke(MethodEvent event, Object invocation, String
2626

2727
if (headers != null && headers.size() > 0) {
2828
hasTaint = false;
29-
TaintPoolUtils.trackObject(event, null, headers, 0);
29+
TaintPoolUtils.trackObject(event, null, headers, 0, false);
3030
if (event.getSourceHashes().size() > sourceLen) {
3131
hasTaint = true;
3232
}

dongtai-core/src/main/java/io/dongtai/iast/core/handler/hookpoint/service/trace/FeignService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public static void solveSyncInvoke(MethodEvent event, AtomicInteger invokeIdSequ
2727

2828
// get args
2929
Object args = event.parameterInstances[0];
30-
TaintPoolUtils.trackObject(event, null, args, 0);
30+
TaintPoolUtils.trackObject(event, null, args, 0, true);
3131

3232
boolean hasTaint = false;
3333
if (!event.getSourceHashes().isEmpty()) {

dongtai-core/src/main/java/io/dongtai/iast/core/utils/ReflectUtils.java

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import java.lang.reflect.Field;
44
import java.lang.reflect.Method;
5+
import java.security.AccessController;
6+
import java.security.PrivilegedAction;
57
import java.util.*;
68

79
/**
@@ -55,8 +57,18 @@ public static Method getPublicMethodFromClass(Class<?> cls, String method) throw
5557

5658
public static Method getPublicMethodFromClass(Class<?> cls, String methodName, Class<?>[] parameterTypes) throws NoSuchMethodException {
5759
Method method = cls.getMethod(methodName, parameterTypes);
58-
method.setAccessible(true);
59-
return method;
60+
return getSecurityPublicMethod(method);
61+
}
62+
63+
public static Method getSecurityPublicMethod(Method method) throws NoSuchMethodException {
64+
if (hasNotSecurityManager()) {
65+
method.setAccessible(true);
66+
return method;
67+
}
68+
return AccessController.doPrivileged((PrivilegedAction<Method>) () -> {
69+
method.setAccessible(true);
70+
return method;
71+
});
6072
}
6173

6274
public static Method getDeclaredMethodFromClass(Class<?> cls, String methodName, Class<?>[] parameterTypes) {
@@ -66,8 +78,11 @@ public static Method getDeclaredMethodFromClass(Class<?> cls, String methodName,
6678
}
6779
for (Method method : methods) {
6880
if (methodName.equals(method.getName()) && Arrays.equals(parameterTypes, method.getParameterTypes())) {
69-
method.setAccessible(true);
70-
return method;
81+
try {
82+
return getSecurityPublicMethod(method);
83+
} catch (NoSuchMethodException e) {
84+
e.printStackTrace();
85+
}
7186
}
7287
}
7388
return null;
@@ -137,13 +152,35 @@ public static List<Class<?>> getAllInterfaces(Class<?> cls) {
137152
private static void getAllInterfaces(Class<?> cls, List<Class<?>> interfaceList) {
138153
while (cls != null) {
139154
Class<?>[] interfaces = cls.getInterfaces();
140-
for (int i = 0; i < interfaces.length; i++) {
141-
if (!interfaceList.contains(interfaces[i])) {
142-
interfaceList.add(interfaces[i]);
143-
getAllInterfaces(interfaces[i], interfaceList);
155+
for (Class<?> anInterface : interfaces) {
156+
if (!interfaceList.contains(anInterface)) {
157+
interfaceList.add(anInterface);
158+
getAllInterfaces(anInterface, interfaceList);
144159
}
145160
}
146161
cls = cls.getSuperclass();
147162
}
148163
}
164+
165+
public static Field[] getDeclaredFieldsSecurity(Class<?> cls) {
166+
Objects.requireNonNull(cls);
167+
if (hasNotSecurityManager()) {
168+
return getDeclaredFields(cls);
169+
}
170+
return (Field[]) AccessController.doPrivileged((PrivilegedAction<Field[]>) () -> {
171+
return getDeclaredFields(cls);
172+
});
173+
}
174+
175+
private static Field[] getDeclaredFields(Class<?> cls) {
176+
Field[] declaredFields = cls.getDeclaredFields();
177+
for (Field field : declaredFields) {
178+
field.setAccessible(true);
179+
}
180+
return declaredFields;
181+
}
182+
183+
private static boolean hasNotSecurityManager() {
184+
return System.getSecurityManager() == null;
185+
}
149186
}

dongtai-core/src/main/java/io/dongtai/iast/core/utils/TaintPoolUtils.java

Lines changed: 34 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import io.dongtai.iast.core.handler.hookpoint.models.policy.SourceNode;
77
import io.dongtai.iast.core.handler.hookpoint.models.taint.range.*;
88
import io.dongtai.log.DongTaiLog;
9-
import io.dongtai.log.ErrorCode;
109

1110
import java.lang.reflect.Array;
12-
import java.lang.reflect.Method;
11+
import java.lang.reflect.Field;
1312
import java.math.BigDecimal;
1413
import java.util.*;
1514

@@ -24,16 +23,6 @@ public class TaintPoolUtils {
2423
private static final String VALUES_ENUMERATOR = " org.apache.tomcat.util.http.ValuesEnumerator".substring(1);
2524
private static final String SPRING_OBJECT = " org.springframework.".substring(1);
2625

27-
/**
28-
* 判断 obj 对象是否为 java 的内置数据类型,包括:string、array、list、map、enum 等
29-
*
30-
* @param obj Object
31-
* @return boolean
32-
*/
33-
public static boolean isJdkType(Object obj) {
34-
return obj instanceof String || obj instanceof Map || obj instanceof List;
35-
}
36-
3726
public static boolean poolContains(Object obj, MethodEvent event) {
3827
if (obj == null) {
3928
return false;
@@ -146,86 +135,7 @@ public static boolean isAllowTaintType(Object obj) {
146135
return isAllowTaintType(obj.getClass());
147136
}
148137

149-
public static boolean isAllowTaintGetterModel(Object model) {
150-
if (!TaintPoolUtils.isNotEmpty(model)) {
151-
return false;
152-
}
153-
Class<?> sourceClass = model.getClass();
154-
if (sourceClass.getClassLoader() == null) {
155-
return false;
156-
}
157-
if (!TaintPoolUtils.isAllowTaintGetterClass(sourceClass)) {
158-
return false;
159-
}
160-
return true;
161-
}
162-
163-
public static boolean isAllowTaintGetterClass(Class<?> clazz) {
164-
String className = clazz.getName();
165-
if (className.startsWith("cn.huoxian.iast.api.") ||
166-
className.startsWith("io.dongtai.api.") ||
167-
className.startsWith(" org.apache.tomcat".substring(1)) ||
168-
className.startsWith(" org.apache.catalina".substring(1)) ||
169-
className.startsWith(" org.apache.shiro.web.servlet".substring(1)) ||
170-
className.startsWith(" org.eclipse.jetty".substring(1)) ||
171-
VALUES_ENUMERATOR.equals(className) ||
172-
className.startsWith(SPRING_OBJECT) ||
173-
className.contains("RequestWrapper") ||
174-
className.contains("ResponseWrapper")
175-
176-
) {
177-
return false;
178-
}
179-
180-
List<Class<?>> interfaces = ReflectUtils.getAllInterfaces(clazz);
181-
for (Class<?> inter : interfaces) {
182-
if (inter.getName().endsWith(".servlet.ServletRequest")
183-
|| inter.getName().endsWith(".servlet.ServletResponse")) {
184-
return false;
185-
}
186-
}
187-
188-
return true;
189-
}
190-
191-
public static boolean isAllowTaintGetterMethod(Method method) {
192-
String methodName = method.getName();
193-
if (!methodName.startsWith("get")
194-
|| "getClass".equals(methodName)
195-
|| "getParserForType".equals(methodName)
196-
|| "getDefaultInstance".equals(methodName)
197-
|| "getDefaultInstanceForType".equals(methodName)
198-
|| "getDescriptor".equals(methodName)
199-
|| "getDescriptorForType".equals(methodName)
200-
|| "getAllFields".equals(methodName)
201-
|| "getInitializationErrorString".equals(methodName)
202-
|| "getUnknownFields".equals(methodName)
203-
|| "getDetailOrBuilderList".equals(methodName)
204-
|| "getAllFieldsMutable".equals(methodName)
205-
|| "getAllFieldsRaw".equals(methodName)
206-
|| "getOneofFieldDescriptor".equals(methodName)
207-
|| "getField".equals(methodName)
208-
|| "getFieldRaw".equals(methodName)
209-
|| "getRepeatedFieldCount".equals(methodName)
210-
|| "getRepeatedField".equals(methodName)
211-
|| "getSerializedSize".equals(methodName)
212-
|| "getMethodOrDie".equals(methodName)
213-
|| "getReader".equals(methodName)
214-
|| "getInputStream".equals(methodName)
215-
|| "getWriter".equals(methodName)
216-
|| "getOutputStream".equals(methodName)
217-
|| "getParameterNames".equals(methodName)
218-
|| "getParameterMap".equals(methodName)
219-
|| "getHeaderNames".equals(methodName)
220-
|| methodName.endsWith("Bytes")
221-
|| method.getParameterCount() != 0) {
222-
return false;
223-
}
224-
225-
return isAllowTaintType(method.getReturnType());
226-
}
227-
228-
public static void trackObject(MethodEvent event, PolicyNode policyNode, Object obj, int depth) {
138+
public static void trackObject(MethodEvent event, PolicyNode policyNode, Object obj, int depth, Boolean isMicroservice) {
229139
if (depth >= 10 || !TaintPoolUtils.isNotEmpty(obj) || !TaintPoolUtils.isAllowTaintType(obj)) {
230140
return;
231141
}
@@ -241,21 +151,21 @@ public static void trackObject(MethodEvent event, PolicyNode policyNode, Object
241151

242152
Class<?> cls = obj.getClass();
243153
if (cls.isArray() && !cls.getComponentType().isPrimitive()) {
244-
trackArray(event, policyNode, obj, depth);
154+
trackArray(event, policyNode, obj, depth, isMicroservice);
245155
} else if (obj instanceof Iterator && !(obj instanceof Enumeration)) {
246-
trackIterator(event, policyNode, (Iterator<?>) obj, depth);
156+
trackIterator(event, policyNode, (Iterator<?>) obj, depth, isMicroservice);
247157
} else if (obj instanceof Map) {
248-
trackMap(event, policyNode, (Map<?, ?>) obj, depth);
158+
trackMap(event, policyNode, (Map<?, ?>) obj, depth, isMicroservice);
249159
} else if (obj instanceof Map.Entry) {
250-
trackMapEntry(event, policyNode, (Map.Entry<?, ?>) obj, depth);
160+
trackMapEntry(event, policyNode, (Map.Entry<?, ?>) obj, depth, isMicroservice);
251161
} else if (obj instanceof Collection && !(obj instanceof Enumeration)) {
252162
if (obj instanceof List) {
253-
trackList(event, policyNode, (List<?>) obj, depth);
163+
trackList(event, policyNode, (List<?>) obj, depth, isMicroservice);
254164
} else {
255-
trackIterator(event, policyNode, ((Collection<?>) obj).iterator(), depth);
165+
trackIterator(event, policyNode, ((Collection<?>) obj).iterator(), depth, isMicroservice);
256166
}
257167
} else if ("java.util.Optional".equals(obj.getClass().getName())) {
258-
trackOptional(event, policyNode, obj, depth);
168+
trackOptional(event, policyNode, obj, depth, isMicroservice);
259169
} else {
260170
if (isSourceNode) {
261171
int len = TaintRangesBuilder.getLength(obj);
@@ -276,6 +186,17 @@ public static void trackObject(MethodEvent event, PolicyNode policyNode, Object
276186
EngineManager.TAINT_HASH_CODES.add(hash);
277187
event.addTargetHash(hash);
278188
EngineManager.TAINT_RANGES_POOL.add(hash, tr);
189+
if (isMicroservice && !(obj instanceof String)) {
190+
try {
191+
Field[] declaredFields = ReflectUtils.getDeclaredFieldsSecurity(cls);
192+
for (Field field : declaredFields) {
193+
trackObject(event, policyNode, field.get(obj), depth + 1, isMicroservice);
194+
}
195+
} catch (Throwable e) {
196+
DongTaiLog.debug("solve model failed: {}, {}",
197+
e.getMessage(), e.getCause() != null ? e.getCause().getMessage() : "");
198+
}
199+
}
279200
} else {
280201
hash = System.identityHashCode(obj);
281202
if (EngineManager.TAINT_HASH_CODES.contains(hash)) {
@@ -285,41 +206,41 @@ public static void trackObject(MethodEvent event, PolicyNode policyNode, Object
285206
}
286207
}
287208

288-
private static void trackArray(MethodEvent event, PolicyNode policyNode, Object arr, int depth) {
209+
private static void trackArray(MethodEvent event, PolicyNode policyNode, Object arr, int depth, Boolean isMicroservice) {
289210
int length = Array.getLength(arr);
290211
for (int i = 0; i < length; i++) {
291-
trackObject(event, policyNode, Array.get(arr, i), depth + 1);
212+
trackObject(event, policyNode, Array.get(arr, i), depth + 1, isMicroservice);
292213
}
293214
}
294215

295-
private static void trackIterator(MethodEvent event, PolicyNode policyNode, Iterator<?> it, int depth) {
216+
private static void trackIterator(MethodEvent event, PolicyNode policyNode, Iterator<?> it, int depth, Boolean isMicroservice) {
296217
while (it.hasNext()) {
297-
trackObject(event, policyNode, it.next(), depth + 1);
218+
trackObject(event, policyNode, it.next(), depth + 1, isMicroservice);
298219
}
299220
}
300221

301-
private static void trackMap(MethodEvent event, PolicyNode policyNode, Map<?, ?> map, int depth) {
222+
private static void trackMap(MethodEvent event, PolicyNode policyNode, Map<?, ?> map, int depth, Boolean isMicroservice) {
302223
for (Object key : map.keySet()) {
303-
trackObject(event, policyNode, key, depth + 1);
304-
trackObject(event, policyNode, map.get(key), depth + 1);
224+
trackObject(event, policyNode, key, depth + 1, isMicroservice);
225+
trackObject(event, policyNode, map.get(key), depth + 1, isMicroservice);
305226
}
306227
}
307228

308-
private static void trackMapEntry(MethodEvent event, PolicyNode policyNode, Map.Entry<?, ?> entry, int depth) {
309-
trackObject(event, policyNode, entry.getKey(), depth + 1);
310-
trackObject(event, policyNode, entry.getValue(), depth + 1);
229+
private static void trackMapEntry(MethodEvent event, PolicyNode policyNode, Map.Entry<?, ?> entry, int depth, Boolean isMicroservice) {
230+
trackObject(event, policyNode, entry.getKey(), depth + 1, isMicroservice);
231+
trackObject(event, policyNode, entry.getValue(), depth + 1, isMicroservice);
311232
}
312233

313-
private static void trackList(MethodEvent event, PolicyNode policyNode, List<?> list, int depth) {
234+
private static void trackList(MethodEvent event, PolicyNode policyNode, List<?> list, int depth, Boolean isMicroservice) {
314235
for (Object obj : list) {
315-
trackObject(event, policyNode, obj, depth + 1);
236+
trackObject(event, policyNode, obj, depth + 1, isMicroservice);
316237
}
317238
}
318239

319-
private static void trackOptional(MethodEvent event, PolicyNode policyNode, Object obj, int depth) {
240+
private static void trackOptional(MethodEvent event, PolicyNode policyNode, Object obj, int depth, Boolean isMicroservice) {
320241
try {
321242
Object v = ((Optional<?>) obj).orElse(null);
322-
trackObject(event, policyNode, v, depth + 1);
243+
trackObject(event, policyNode, v, depth + 1, isMicroservice);
323244
} catch (Throwable ignore) {
324245
}
325246
}

0 commit comments

Comments
 (0)