13
13
# limitations under the License.
14
14
15
15
import collections
16
+ import enum
16
17
import fnmatch
17
18
import re
18
19
import shlex
@@ -200,6 +201,12 @@ def deserialize(cls, model, name, subdirectory, stages):
200
201
return cls (name , subdirectory , task_type , model_stages , description , framework ,
201
202
license_url , precisions , quantization_output_precisions , name )
202
203
204
+ class ModelLoadingMode (enum .Enum ):
205
+ all = 0 # return all models
206
+ composite_only = 1 # return only composite models
207
+ non_composite_only = 2 # return only non composite models
208
+ ignore_composite = 3 # ignore composite structure, return flatten models list
209
+
203
210
def check_composite_model_dir (model_dir ):
204
211
with validation .deserialization_context ('In directory "{}"' .format (model_dir )):
205
212
if list (model_dir .glob ('*/*/**/model.yml' )):
@@ -217,7 +224,7 @@ def check_composite_model_dir(model_dir):
217
224
raise validation .DeserializationError (
218
225
'Names of composite model parts should start with composite model name' )
219
226
220
- def load_models (models_root , args ):
227
+ def load_models (models_root , args , mode = ModelLoadingMode . all ):
221
228
models = []
222
229
model_names = set ()
223
230
@@ -226,67 +233,73 @@ def load_models(models_root, args):
226
233
227
234
schema = _common .get_schema ()
228
235
229
- for composite_model_config in sorted (models_root .glob ('**/composite-model.yml' )):
230
- composite_model_name = composite_model_config .parent .name
231
- with validation .deserialization_context ('In model "{}"' .format (composite_model_name )):
232
- if not RE_MODEL_NAME .fullmatch (composite_model_name ):
233
- raise validation .DeserializationError ('Invalid name, must consist only of letters, digits or ._-' )
236
+ if mode in (ModelLoadingMode .all , ModelLoadingMode .composite_only ):
234
237
235
- check_composite_model_dir (composite_model_config .parent )
238
+ for composite_model_config in sorted (models_root .glob ('**/composite-model.yml' )):
239
+ composite_model_name = composite_model_config .parent .name
240
+ with validation .deserialization_context ('In model "{}"' .format (composite_model_name )):
241
+ if not RE_MODEL_NAME .fullmatch (composite_model_name ):
242
+ raise validation .DeserializationError ('Invalid name, must consist only of letters, digits or ._-' )
236
243
237
- with composite_model_config .open ('rb' ) as config_file , \
238
- validation .deserialization_context ('In config "{}"' .format (composite_model_config )):
244
+ check_composite_model_dir (composite_model_config .parent )
239
245
240
- composite_model = yaml .safe_load (config_file )
241
- model_stages = {}
242
- for stage in sorted (composite_model_config .parent .glob ('*/model.yml' )):
243
- with stage .open ('rb' ) as stage_config_file , \
244
- validation .deserialization_context ('In config "{}"' .format (stage_config_file )):
245
- model = yaml .safe_load (stage_config_file )
246
- if not schema .check (model ):
247
- raise validation .DeserializationError ('Configuration file check was\' t successful.' )
246
+ with composite_model_config .open ('rb' ) as config_file , \
247
+ validation .deserialization_context ('In config "{}"' .format (composite_model_config )):
248
248
249
- stage_subdirectory = stage .parent .relative_to (models_root )
250
- model_stages [stage_subdirectory ] = model
249
+ composite_model = yaml .safe_load (config_file )
250
+ model_stages = {}
251
+ for stage in sorted (composite_model_config .parent .glob ('*/model.yml' )):
252
+ with stage .open ('rb' ) as stage_config_file , \
253
+ validation .deserialization_context ('In config "{}"' .format (stage_config_file )):
254
+ model = yaml .safe_load (stage_config_file )
255
+ if not schema .check (model ):
256
+ raise validation .DeserializationError ('Configuration file check was\' t successful.' )
251
257
252
- if len (model_stages ) == 0 :
253
- continue
254
- subdirectory = composite_model_config .parent .relative_to (models_root )
255
- composite_models .append (CompositeModel .deserialize (
256
- composite_model , composite_model_name , subdirectory , model_stages
257
- ))
258
+ stage_subdirectory = stage .parent .relative_to (models_root )
259
+ model_stages [stage_subdirectory ] = model
258
260
259
- if composite_model_name in composite_model_names :
260
- raise validation .DeserializationError (
261
- 'Duplicate composite model name "{}"' .format (composite_model_name ))
262
- composite_model_names .add (composite_model_name )
261
+ if len (model_stages ) == 0 :
262
+ continue
263
+ subdirectory = composite_model_config .parent .relative_to (models_root )
264
+ composite_models .append (CompositeModel .deserialize (
265
+ composite_model , composite_model_name , subdirectory , model_stages
266
+ ))
263
267
264
- for config_path in sorted (models_root .glob ('**/model.yml' )):
265
- subdirectory = config_path .parent
268
+ if composite_model_name in composite_model_names :
269
+ raise validation .DeserializationError (
270
+ 'Duplicate composite model name "{}"' .format (composite_model_name ))
271
+ composite_model_names .add (composite_model_name )
266
272
267
- is_composite = ( subdirectory . parent / 'composite-model.yml' ). exists ()
268
- if is_composite :
269
- continue
273
+ if mode != ModelLoadingMode . composite_only :
274
+ for config_path in sorted ( models_root . glob ( '**/model.yml' )) :
275
+ subdirectory = config_path . parent
270
276
271
- subdirectory = subdirectory .relative_to (models_root )
277
+ is_composite = (subdirectory .parent / 'composite-model.yml' ).exists ()
278
+ composite_model_name = None
279
+ if is_composite :
280
+ if mode != ModelLoadingMode .ignore_composite :
281
+ continue
282
+ composite_model_name = subdirectory .parent .name
272
283
273
- with config_path .open ('rb' ) as config_file , \
274
- validation .deserialization_context ('In config "{}"' .format (config_path )):
284
+ subdirectory = subdirectory .relative_to (models_root )
275
285
276
- model = yaml .safe_load (config_file )
277
- if not schema .check (model ):
278
- raise validation .DeserializationError ('Configuration file check was\' t successful.' )
286
+ with config_path .open ('rb' ) as config_file , \
287
+ validation .deserialization_context ('In config "{}"' .format (config_path )):
279
288
280
- for bad_key in [ 'name' , 'subdirectory' ]:
281
- if bad_key in model :
282
- raise validation .DeserializationError ('Unsupported key "{}"' . format ( bad_key ) )
289
+ model = yaml . safe_load ( config_file )
290
+ if not schema . check ( model ) :
291
+ raise validation .DeserializationError ('Configuration file check was \' t successful.' )
283
292
284
- models .append (Model .deserialize (model , subdirectory .name , subdirectory , None ))
293
+ for bad_key in ['name' , 'subdirectory' ]:
294
+ if bad_key in model :
295
+ raise validation .DeserializationError ('Unsupported key "{}"' .format (bad_key ))
285
296
286
- if models [- 1 ].name in model_names :
287
- raise validation .DeserializationError (
288
- 'Duplicate model name "{}"' .format (models [- 1 ].name ))
289
- model_names .add (models [- 1 ].name )
297
+ models .append (Model .deserialize (model , subdirectory .name , subdirectory , composite_model_name ))
298
+
299
+ if models [- 1 ].name in model_names :
300
+ raise validation .DeserializationError (
301
+ 'Duplicate model name "{}"' .format (models [- 1 ].name ))
302
+ model_names .add (models [- 1 ].name )
290
303
291
304
return sorted (models + composite_models , key = lambda model : model .name )
292
305
0 commit comments