@@ -50,6 +50,8 @@ class RecipeTypes(Enum):
5050 """
5151
5252 ORIGINAL = "original"
53+ SPARSE = "sparse"
54+ TRANSFER = "transfer"
5355 TRANSFER_LEARN = "transfer_learn"
5456
5557
@@ -455,9 +457,6 @@ def search_sparse_recipes(
455457 """
456458 from sparsezoo .objects .model import Model
457459
458- if isinstance (recipe_type , str ):
459- recipe_type = RecipeTypes (recipe_type ).value
460-
461460 if not isinstance (model , Model ):
462461 model = Model .load_model_from_stub (model )
463462
@@ -508,15 +507,21 @@ def recipe_type_original(self) -> bool:
508507 :return: True if this is the original recipe that created the
509508 model, False otherwise
510509 """
511- return self .recipe_type == RecipeTypes .ORIGINAL .value
510+ return any (
511+ self .recipe_type .startswith (start )
512+ for start in [RecipeTypes .ORIGINAL .value , RecipeTypes .SPARSE .value ]
513+ )
512514
513515 @property
514516 def recipe_type_transfer_learn (self ) -> bool :
515517 """
516518 :return: True if this is a recipe for transfer learning from the
517519 created model, False otherwise
518520 """
519- return self .recipe_type == RecipeTypes .TRANSFER_LEARN .value
521+ return any (
522+ self .recipe_type .startswith (start )
523+ for start in [RecipeTypes .TRANSFER .value , RecipeTypes .TRANSFER_LEARN .value ]
524+ )
520525
521526 @property
522527 def display_name (self ):
@@ -653,15 +658,17 @@ def download_base_framework_files(
653658 return base_framework_files or framework_files
654659
655660
656- def _get_stub_args_recipe_type (stub_args : Dict [str , str ]) -> str :
661+ def _get_stub_args_recipe_type (stub_args : Dict [str , str ]) -> Optional [ str ] :
657662 # check recipe type, default to original, and validate
658663 recipe_type = stub_args .get ("recipe_type" )
659-
660- # validate
661664 valid_recipe_types = list (map (lambda typ : typ .value , RecipeTypes ))
662- if recipe_type not in valid_recipe_types and recipe_type is not None :
665+
666+ if recipe_type is not None and not any (
667+ recipe_type .startswith (start ) for start in valid_recipe_types
668+ ):
663669 raise ValueError (
664670 f"Invalid recipe_type: '{ recipe_type } '. "
665- f"Valid recipe types : { valid_recipe_types } "
671+ f"Valid recipes must start with one of : { valid_recipe_types } "
666672 )
673+
667674 return recipe_type
0 commit comments