Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 0aec132

Browse files
bfinerannatuan
andauthored
Retrain ResNet20 model to work with older TF versions (#127) (#128)
* Retrain model to work with older TF versions * Load pruned recipe * Using recipe stub Co-authored-by: Tuan Nguyen <[email protected]>
1 parent dde3373 commit 0aec132

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

integrations/keras/prune_resnet20.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def download_model_and_recipe(root_dir: str):
6565
Download pretrained model and a pruning recipe
6666
"""
6767
model_dir = os.path.join(root_dir, "resnet20_v1")
68-
zoo_model = Zoo.load_model(
68+
69+
# Load base model to prune
70+
base_zoo_model = Zoo.load_model(
6971
domain="cv",
7072
sub_domain="classification",
7173
architecture="resnet_v1",
@@ -74,18 +76,21 @@ def download_model_and_recipe(root_dir: str):
7476
repo="sparseml",
7577
dataset="cifar_10",
7678
training_scheme=None,
77-
optim_name="pruned",
78-
optim_category="conservative",
79+
optim_name="base",
80+
optim_category=None,
7981
optim_target=None,
8082
override_parent_path=model_dir,
8183
)
82-
zoo_model.download()
83-
model_file_path = zoo_model.framework_files[0].downloaded_path()
84+
base_zoo_model.download()
85+
model_file_path = base_zoo_model.framework_files[0].downloaded_path()
8486
if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"):
8587
raise RuntimeError("Model file not found: {}".format(model_file_path))
86-
recipe_file_path = zoo_model.recipes[0].downloaded_path()
87-
if not os.path.exists(recipe_file_path):
88-
raise RuntimeError("Recipe file not found: {}".format(recipe_file_path))
88+
89+
# Simply use the recipe stub
90+
recipe_file_path = (
91+
"zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative"
92+
)
93+
8994
return model_file_path, recipe_file_path
9095

9196

@@ -132,6 +137,7 @@ def main():
132137
(X_train, y_train), (X_test, y_test) = load_and_normalize_cifar10()
133138

134139
model_file_path, recipe_file_path = download_model_and_recipe(root_dir)
140+
135141
print("Load pretrained model")
136142
base_model = tf.keras.models.load_model(model_file_path)
137143
base_model.summary()

0 commit comments

Comments
 (0)