19
19
package org .apache .gravitino .catalog .jdbc .utils ;
20
20
21
21
import java .sql .SQLException ;
22
+ import java .util .Arrays ;
23
+ import java .util .List ;
22
24
import java .util .Map ;
23
25
import java .util .Properties ;
24
26
import javax .sql .DataSource ;
@@ -37,6 +39,26 @@ public class DataSourceUtils {
37
39
/** SQL statements for database connection pool testing. */
38
40
private static final String POOL_TEST_QUERY = "SELECT 1" ;
39
41
42
+ private static final List <String > unsafeMySQLParameters =
43
+ Arrays .asList (
44
+ "maxAllowedPacket" ,
45
+ "autoDeserialize" ,
46
+ "queryInterceptors" ,
47
+ "statementInterceptors" ,
48
+ "detectCustomCollations" ,
49
+ "allowloadlocalinfile" ,
50
+ "allowUrlInLocalInfile" ,
51
+ "allowLoadLocalInfileInPath" );
52
+
53
+ private static final List <String > unsafePostgresParameters =
54
+ Arrays .asList (
55
+ "socketFactory" ,
56
+ "socketFactoryArg" ,
57
+ "sslfactory" ,
58
+ "sslhostnameverifier" ,
59
+ "sslpasswordcallback" ,
60
+ "authenticationPluginClassName" );
61
+
40
62
public static DataSource createDataSource (Map <String , String > properties ) {
41
63
return createDataSource (new JdbcConfig (properties ));
42
64
}
@@ -51,6 +73,8 @@ public static DataSource createDataSource(JdbcConfig jdbcConfig)
51
73
}
52
74
53
75
private static DataSource createDBCPDataSource (JdbcConfig jdbcConfig ) throws Exception {
76
+ validateJdbcConfig (
77
+ jdbcConfig .getJdbcDriver (), jdbcConfig .getJdbcUrl (), jdbcConfig .getAllConfig ());
54
78
BasicDataSource basicDataSource =
55
79
BasicDataSourceFactory .createDataSource (getProperties (jdbcConfig ));
56
80
String jdbcUrl = jdbcConfig .getJdbcUrl ();
@@ -76,6 +100,61 @@ private static Properties getProperties(JdbcConfig jdbcConfig) {
76
100
return properties ;
77
101
}
78
102
103
+ private static String recursiveDecode (String url ) {
104
+ String prev ;
105
+ String decoded = url ;
106
+ int max = 5 ;
107
+
108
+ do {
109
+ prev = decoded ;
110
+ try {
111
+ decoded = java .net .URLDecoder .decode (prev , "UTF-8" );
112
+ } catch (Exception e ) {
113
+ throw new GravitinoRuntimeException ("Unable to decode JDBC URL" );
114
+ }
115
+ } while (!prev .equals (decoded ) && --max > 0 );
116
+
117
+ return decoded ;
118
+ }
119
+
120
+ private static void checkUnsafeParameters (
121
+ String url , Map <String , String > config , List <String > unsafeParams , String dbType ) {
122
+
123
+ String lowerUrl = url .toLowerCase ();
124
+
125
+ for (String param : unsafeParams ) {
126
+ String lowerParam = param .toLowerCase ();
127
+ if (lowerUrl .contains (lowerParam ) || containsValueIgnoreCase (config , param )) {
128
+ throw new GravitinoRuntimeException (
129
+ "Unsafe %s parameter '%s' detected in JDBC URL" , dbType , param );
130
+ }
131
+ }
132
+ }
133
+
134
+ public static void validateJdbcConfig (String driver , String url , Map <String , String > all ) {
135
+ String lowerUrl = url .toLowerCase ();
136
+ String decodedUrl = recursiveDecode (lowerUrl );
137
+
138
+ if (driver != null ) {
139
+ if (decodedUrl .startsWith ("jdbc:mysql" )) {
140
+ checkUnsafeParameters (decodedUrl , all , unsafeMySQLParameters , "MySQL" );
141
+ } else if (decodedUrl .startsWith ("jdbc:mariadb" )) {
142
+ checkUnsafeParameters (decodedUrl , all , unsafeMySQLParameters , "MariaDB" );
143
+ } else if (decodedUrl .startsWith ("jdbc:postgresql" )) {
144
+ checkUnsafeParameters (decodedUrl , all , unsafePostgresParameters , "PostgreSQL" );
145
+ }
146
+ }
147
+ }
148
+
149
+ private static boolean containsValueIgnoreCase (Map <String , String > map , String value ) {
150
+ for (String keyValue : map .values ()) {
151
+ if (keyValue != null && keyValue .equalsIgnoreCase (value )) {
152
+ return true ;
153
+ }
154
+ }
155
+ return false ;
156
+ }
157
+
79
158
public static void closeDataSource (DataSource dataSource ) {
80
159
if (null != dataSource ) {
81
160
try {
0 commit comments