Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 9e013da

Browse files
authored
correct pytorch model sub-architecture names (#57)
1 parent 94f3d42 commit 9e013da

File tree

6 files changed

+14
-9
lines changed

6 files changed

+14
-9
lines changed

src/sparseml/pytorch/models/classification/inception_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def forward(self, x_tens: Tensor) -> Tuple[Tensor, ...]:
470470
domain="cv",
471471
sub_domain="classification",
472472
architecture="inception_v3",
473-
sub_architecture="none",
473+
sub_architecture=None,
474474
default_dataset="imagenet",
475475
default_desc="base",
476476
def_ignore_error_tensors=[

src/sparseml/pytorch/models/classification/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def resnetv2_50(num_classes: int = 1000, class_type: str = "single") -> ResNet:
825825
domain="cv",
826826
sub_domain="classification",
827827
architecture="resnet_v1",
828-
sub_architecture="50-2xwidth",
828+
sub_architecture="50_2x",
829829
default_dataset="imagenet",
830830
default_desc="base",
831831
def_ignore_error_tensors=["classifier.fc.weight", "classifier.fc.bias"],
@@ -1067,7 +1067,7 @@ def resnetv2_101(num_classes: int = 1000, class_type: str = "single") -> ResNet:
10671067
domain="cv",
10681068
sub_domain="classification",
10691069
architecture="resnet_v1",
1070-
sub_architecture="101-2xwidth",
1070+
sub_architecture="101_2x",
10711071
default_dataset="imagenet",
10721072
default_desc="base",
10731073
def_ignore_error_tensors=["classifier.fc.weight", "classifier.fc.bias"],

src/sparseml/pytorch/models/classification/vgg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def vgg11(num_classes: int = 1000, class_type: str = "single") -> VGG:
238238
domain="cv",
239239
sub_domain="classification",
240240
architecture="vgg",
241-
sub_architecture="11-bn",
241+
sub_architecture="11_bn",
242242
default_dataset="imagenet",
243243
default_desc="base",
244244
def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
@@ -324,7 +324,7 @@ def vgg13(num_classes: int = 1000, class_type: str = "single") -> VGG:
324324
domain="cv",
325325
sub_domain="classification",
326326
architecture="vgg",
327-
sub_architecture="13-bn",
327+
sub_architecture="13_bn",
328328
default_dataset="imagenet",
329329
default_desc="base",
330330
def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
@@ -410,7 +410,7 @@ def vgg16(num_classes: int = 1000, class_type: str = "single") -> VGG:
410410
domain="cv",
411411
sub_domain="classification",
412412
architecture="vgg",
413-
sub_architecture="16-bn",
413+
sub_architecture="16_bn",
414414
default_dataset="imagenet",
415415
default_desc="base",
416416
def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],
@@ -496,7 +496,7 @@ def vgg19(num_classes: int = 1000, class_type: str = "single") -> VGG:
496496
domain="cv",
497497
sub_domain="classification",
498498
architecture="vgg",
499-
sub_architecture="19-bn",
499+
sub_architecture="19_bn",
500500
default_dataset="imagenet",
501501
default_desc="base",
502502
def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"],

src/sparseml/pytorch/models/detection/yolo_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def forward(self, inp: Tensor):
276276
domain="cv",
277277
sub_domain="detection",
278278
architecture="yolo_v3",
279-
sub_architecture="none",
279+
sub_architecture="spp",
280280
default_dataset="coco",
281281
default_desc="base",
282282
)

src/sparseml/tensorflow_v1/optim/mask_pruning.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from collections import namedtuple
2121
from typing import List, Tuple
2222

23+
2324
try:
2425
import tensorflow.contrib.graph_editor as graph_editor
26+
2527
tf_contrib_err = None
2628
except Exception as err:
2729
graph_editor = None

src/sparseml/tensorflow_v1/utils/variable.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy
1919

20+
2021
try:
2122
import tensorflow.contrib.graph_editor as graph_editor
2223
from tensorflow.contrib.graph_editor.util import ListView
@@ -239,7 +240,9 @@ def get_ops_and_inputs_by_name_or_regex(
239240
nm_ks_consuming_ops_with_input = [
240241
(consuming_op, inp)
241242
for output_tens in graph_editor.sgv(op).outputs
242-
for consuming_op in graph_editor.get_consuming_ops(output_tens)
243+
for consuming_op in graph_editor.get_consuming_ops(
244+
output_tens
245+
)
243246
if "_nm_ks" not in consuming_op.name
244247
]
245248
prunable_ops_and_inputs += nm_ks_consuming_ops_with_input

0 commit comments

Comments
 (0)