44import java .io .IOException ;
55import java .lang .reflect .InvocationTargetException ;
66import java .util .ArrayList ;
7+ import java .util .Arrays ;
78import java .util .HashMap ;
9+ import java .util .HashSet ;
810import java .util .Iterator ;
911import java .util .Map ;
10- import org .apache .spark .Partitioner ;
11- import org .apache .spark .sql .catalyst .expressions .Attribute ;
12- import org .apache .spark .sql .catalyst .plans .JoinType ;
12+ import java .util .Set ;
1313import org .apache .spark .sql .catalyst .plans .QueryPlan ;
14- import org .apache .spark .sql .catalyst .plans .physical .BroadcastMode ;
1514import org .apache .spark .sql .catalyst .plans .physical .Partitioning ;
1615import org .apache .spark .sql .catalyst .trees .TreeNode ;
1716import scala .Option ;
@@ -25,16 +24,37 @@ public abstract class AbstractSparkPlanSerializer {
2524 private final int MAX_LENGTH = 50 ;
2625 private final ObjectMapper mapper = AbstractDatadogSparkListener .objectMapper ;
2726
28- private final Class [] SAFE_CLASSES = {
29- Attribute .class , // simpleString appends data type, avoid by using toString
30- JoinType .class , // enum
31- Partitioner .class , // not a product or TreeNode
32- BroadcastMode .class , // not a product or TreeNode
33- maybeGetClass ("org.apache.spark.sql.execution.exchange.ShuffleOrigin" ), // enum (v3+)
34- maybeGetClass ("org.apache.spark.sql.catalyst.optimizer.BuildSide" ), // enum (v3+)
35- maybeGetClass (
36- "org.apache.spark.sql.execution.ShufflePartitionSpec" ), // not a product or TreeNode (v3+)
37- };
27+ private final String SPARK_PKG_NAME = "org.apache.spark" ;
28+ private final Set <String > SAFE_CLASS_NAMES =
29+ new HashSet <>(
30+ Arrays .asList (
31+ SPARK_PKG_NAME + ".Partitioner" , // not a product or TreeNode
32+ SPARK_PKG_NAME
33+ + ".sql.catalyst.expressions.Attribute" , // avoid data type added by simpleString
34+ SPARK_PKG_NAME + ".sql.catalyst.optimizer.BuildSide" , // enum (v3+)
35+ SPARK_PKG_NAME + ".sql.catalyst.plans.JoinType" , // enum
36+ SPARK_PKG_NAME
37+ + ".sql.catalyst.plans.physical.BroadcastMode" , // not a product or TreeNode
38+ SPARK_PKG_NAME
39+ + ".sql.execution.ShufflePartitionSpec" , // not a product or TreeNode (v3+)
40+ SPARK_PKG_NAME + ".sql.execution.exchange.ShuffleOrigin" // enum (v3+)
41+ ));
42+
43+ // Add class here if we want to break inheritance and interface traversal early when we see
44+ // this class. Any class added must be a class whose parents we do not want to match
45+ // (inclusive of the class itself).
46+ private final Set <String > NEGATIVE_CACHE_CLASSES =
47+ new HashSet <>(
48+ Arrays .asList (
49+ "java.io.Serializable" ,
50+ "java.lang.Object" ,
51+ "scala.Equals" ,
52+ "scala.Product" ,
53+ SPARK_PKG_NAME + ".sql.catalyst.InternalRow" ,
54+ SPARK_PKG_NAME + ".sql.catalyst.expressions.Expression" ,
55+ SPARK_PKG_NAME + ".sql.catalyst.expressions.UnaryExpression" ,
56+ SPARK_PKG_NAME + ".sql.catalyst.expressions.Unevaluable" ,
57+ SPARK_PKG_NAME + ".sql.catalyst.trees.TreeNode" ));
3858
/**
 * Supplies the key for position {@code idx} of {@code node}.
 *
 * <p>NOTE(review): implementations and call sites are outside this view; presumably the key
 * names the idx-th serialized field of the node — confirm against the concrete subclasses.
 */
3959 public abstract String getKey (int idx , TreeNode node );
4060
@@ -117,7 +137,7 @@ protected Object safeParseObjectToJson(Object value, int depth) {
117137 } else {
118138 return value .toString ();
119139 }
120- } else if (instanceOf (value , SAFE_CLASSES )) {
140+ } else if (traversedInstanceOf (value , SAFE_CLASS_NAMES , NEGATIVE_CACHE_CLASSES )) {
121141 return value .toString ();
122142 } else if (value instanceof TreeNode ) {
123143 // fallback case, leave at bottom
@@ -147,22 +167,44 @@ private String getSimpleString(TreeNode value) {
147167 }
148168 }
149169
150- // Use reflection rather than native `instanceof` for classes added in later Spark versions
151- private boolean instanceOf (Object value , Class [] classes ) {
152- for (Class cls : classes ) {
153- if (cls != null && cls .isInstance (value )) {
170+ private boolean traversedInstanceOf (
171+ Object value , Set <String > expectedClasses , Set <String > negativeCache ) {
172+ if (instanceOf (value .getClass (), expectedClasses , negativeCache )) {
173+ return true ;
174+ }
175+
176+ // Traverse up inheritance tree to check for matches
177+ int lim = 0 ;
178+ Class currClass = value .getClass ();
179+ while (currClass .getSuperclass () != null && lim < MAX_DEPTH ) {
180+ currClass = currClass .getSuperclass ();
181+ if (negativeCache .contains (currClass .getName ())) {
182+ // don't traverse known paths
183+ break ;
184+ }
185+ if (instanceOf (currClass , expectedClasses , negativeCache )) {
154186 return true ;
155187 }
188+ lim += 1 ;
156189 }
157190
158191 return false ;
159192 }
160193
161- private Class maybeGetClass (String cls ) {
162- try {
163- return Class .forName (cls );
164- } catch (ClassNotFoundException e ) {
165- return null ;
194+ private boolean instanceOf (Class cls , Set <String > expectedClasses , Set <String > negativeCache ) {
195+ // Match on strings to avoid class loading errors
196+ if (expectedClasses .contains (cls .getName ())) {
197+ return true ;
166198 }
199+
200+ // Check interfaces as well
201+ for (Class interfaceClass : cls .getInterfaces ()) {
202+ if (!negativeCache .contains (interfaceClass .getName ())
203+ && instanceOf (interfaceClass , expectedClasses , negativeCache )) {
204+ return true ;
205+ }
206+ }
207+
208+ return false ;
167209 }
168210}
0 commit comments