@@ -65,7 +65,9 @@ def download_model_and_recipe(root_dir: str):
6565 Download pretrained model and a pruning recipe
6666 """
6767 model_dir = os .path .join (root_dir , "resnet20_v1" )
68- zoo_model = Zoo .load_model (
68+
69+ # Load base model to prune
70+ base_zoo_model = Zoo .load_model (
6971 domain = "cv" ,
7072 sub_domain = "classification" ,
7173 architecture = "resnet_v1" ,
@@ -74,18 +76,21 @@ def download_model_and_recipe(root_dir: str):
7476 repo = "sparseml" ,
7577 dataset = "cifar_10" ,
7678 training_scheme = None ,
77- optim_name = "pruned " ,
78- optim_category = "conservative" ,
79+ optim_name = "base " ,
80+ optim_category = None ,
7981 optim_target = None ,
8082 override_parent_path = model_dir ,
8183 )
82- zoo_model .download ()
83- model_file_path = zoo_model .framework_files [0 ].downloaded_path ()
84+ base_zoo_model .download ()
85+ model_file_path = base_zoo_model .framework_files [0 ].downloaded_path ()
8486 if not os .path .exists (model_file_path ) or not model_file_path .endswith (".h5" ):
8587 raise RuntimeError ("Model file not found: {}" .format (model_file_path ))
86- recipe_file_path = zoo_model .recipes [0 ].downloaded_path ()
87- if not os .path .exists (recipe_file_path ):
88- raise RuntimeError ("Recipe file not found: {}" .format (recipe_file_path ))
88+
89+ # Simply use the recipe stub
90+ recipe_file_path = (
91+ "zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative"
92+ )
93+
8994 return model_file_path , recipe_file_path
9095
9196
@@ -132,6 +137,7 @@ def main():
132137 (X_train , y_train ), (X_test , y_test ) = load_and_normalize_cifar10 ()
133138
134139 model_file_path , recipe_file_path = download_model_and_recipe (root_dir )
140+
135141 print ("Load pretrained model" )
136142 base_model = tf .keras .models .load_model (model_file_path )
137143 base_model .summary ()
0 commit comments