@@ -170,7 +170,7 @@ def url(self, container_name, object_name):
170
170
171
171
class CloudObjectStorage (object ):
172
172
173
- def __init__ (self , sparkcontext , credentials , configuration_name = '' , bucket_name = '' ):
173
+ def __init__ (self , sparkcontext , credentials , configuration_name = '' , cos_type = 'classic_cos' , auth_method = 'api_key' , bucket_name = '' ):
174
174
175
175
'''
176
176
sparkcontext: a SparkContext object.
@@ -191,10 +191,12 @@ def __init__(self, sparkcontext, credentials, configuration_name='', bucket_name
191
191
you use the url function.
192
192
193
193
'''
194
+ # check if all required values are availble
195
+ _validate_input (credentials , cos_type , auth_method )
196
+
194
197
self .bucket_name = bucket_name
195
198
self .conf_name = configuration_name
196
199
197
- # check if all required values are availble
198
200
credential_key_list = ["endpoint" , "access_key" , "secret_key" ]
199
201
200
202
for i in range (len (credential_key_list )):
@@ -215,6 +217,34 @@ def __init__(self, sparkcontext, credentials, configuration_name='', bucket_name
215
217
hconf .set (prefix + ".access.key" , credentials ['access_key' ])
216
218
hconf .set (prefix + ".secret.key" , credentials ['secret_key' ])
217
219
220
+ def _validate_input (self , credentials , cos_type , auth_method ):
221
+ required_key_classic_cos = ["endpoint" , "access_key" , "secret_key" ]
222
+ required_key_list_iam_api_key = ["endpoint" , "api_key" , "service_id" ]
223
+ required_key_list_iam_token = ["endpoint" , "token" , "service_id" ]
224
+
225
+ def _get_required_keys (cos_type , auth_method ):
226
+ if (cos_type == "bluemix_cos" ):
227
+ if (auth_method == "api_key" ):
228
+ return required_key_list_iam_api_key
229
+ elif (auth_method == "iam_token" )
230
+ return required_key_list_iam_token
231
+ else :
232
+ raise ValueError ("Invalid input: auth_method. auth_method is optional but if set, it should have one of the following values: api_key, iam_token" )
233
+ elif (cos_type == "classic_cos" ):
234
+ return required_key_classic_cos
235
+ else :
236
+ raise ValueError ("Invalid input: cos_type. cos_type is optional but if set, it should have one of the following values: classic_cos, bluemix_cos" )
237
+
238
+ # check keys
239
+ required_key_list = _get_required_keys ()
240
+
241
+ for i in range (len (required_key_list )):
242
+ key = required_key_list [i ]
243
+ if (key not in credentials ):
244
+ raise ValueError ("Invalid input: credentials. {} is required!" .format (key ))
245
+
246
+ return True
247
+
218
248
def url (self , object_name , bucket_name = '' ):
219
249
bucket_name_var = ''
220
250
service_name = DEFAULT_SERVICE_NAME
0 commit comments