
Commit 70a4d98

mergennachin authored and facebook-github-bot committed
Fix Pyre problems in executorch (pytorch#3955)
Summary:
Pull Request resolved: pytorch#3955

Current problems: https://www.internalfb.com/phabricator/paste/view/P1410619602

We will eventually migrate from Pyre to mypy, but in the meantime, here's a fix.

Reviewed By: tarun292

Differential Revision: D58473224

fbshipit-source-id: a2e789b6c324c98d64c26f5ad5103de60dce513d
1 parent 189d548 commit 70a4d98

File tree

12 files changed: +15 / -27 lines

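Most of the deletions in the hunks below remove standalone # pyre-ignore comments above torchvision imports and call sites that Pyre evidently resolves cleanly now. A minimal sketch of what such a call site looks like after the cleanup (assuming torch and torchvision are installed; it mirrors the mobilenet_v2 change below rather than adding anything new):

import logging

import torch
from torchvision.models import mobilenet_v2
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights


def get_eager_model() -> torch.nn.Module:
    # No per-line Pyre suppression is needed: the import and the weights
    # enum both resolve for the type checker.
    logging.info("Loading mobilenet_v2 model")
    return mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)


if __name__ == "__main__":
    print(type(get_eager_model().eval()).__name__)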

backends/xnnpack/test/models/inception_v3.py

Lines changed: 0 additions & 1 deletion

@@ -13,7 +13,6 @@
 
 
 class TestInceptionV3(unittest.TestCase):
-    # pyre-ignore
     ic3 = models.inception_v3(weights="IMAGENET1K_V1").eval()  # noqa
     model_inputs = (torch.randn(1, 3, 224, 224),)
 

examples/models/inception_v3/model.py

Lines changed: 0 additions & 2 deletions

@@ -8,7 +8,6 @@
 
 import torch
 
-# pyre-ignore
 from torchvision.models import inception_v3  # @manual
 
 from ..model_base import EagerModelBase
@@ -20,7 +19,6 @@ def __init__(self):
 
     def get_eager_model(self) -> torch.nn.Module:
         logging.info("Loading torchvision inception_v3 model")
-        # pyre-ignore
         inception_v3_model = inception_v3(weights="IMAGENET1K_V1")
         logging.info("Loaded torchvision inception_v3 model")
         return inception_v3_model

examples/models/llama2/lib/partitioner_lib.py

Lines changed: 6 additions & 2 deletions

@@ -57,8 +57,12 @@ def get_coreml_partitioner(args):
     ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
     try:
         import coremltools as ct
-        from executorch.backends.apple.coreml.compiler import CoreMLBackend
-        from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+        from executorch.backends.apple.coreml.compiler import (  # pyre-ignore
+            CoreMLBackend,
+        )
+        from executorch.backends.apple.coreml.partition import (  # pyre-ignore
+            CoreMLPartitioner,
+        )
     except ImportError:
         raise ImportError(
             "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"

examples/models/llama2/source_transformation/quantize.py

Lines changed: 2 additions & 5 deletions

@@ -17,13 +17,11 @@
 from ..builder import DType
 
 try:
-    # pyre-ignore[21]: Undefined import.
     from fairseq2.nn.embedding import (
         Embedding as fsEmbedding,
         StandardEmbedding as fsStandardEmbedding,
     )
 
-    # pyre-ignore[21]: Undefined import.
     from fairseq2.nn.projection import Linear as fsLinear
 
     print("Using fairseq2 modules.")
@@ -98,10 +96,9 @@ def quantize(
 
         try:
             # torchao 0.3+
-            # pyre-ignore
             from torchao._eval import InputRecorder
         except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder
+            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
 
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
 
@@ -113,7 +110,7 @@ def quantize(
         )
 
         inputs = (
-            InputRecorder(  # pyre-ignore
+            InputRecorder(
                 tokenizer,
                 calibration_seq_length,
                 None,  # input_prep_func

examples/models/mobilenet_v2/model.py

Lines changed: 0 additions & 3 deletions

@@ -8,7 +8,6 @@
 
 import torch
 
-# pyre-ignore
 from torchvision.models import mobilenet_v2  # @manual
 from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
 
@@ -21,7 +20,6 @@ def __init__(self):
 
     def get_eager_model(self) -> torch.nn.Module:
         logging.info("Loading mobilenet_v2 model")
-        # pyre-ignore
         mv2 = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
         logging.info("Loaded mobilenet_v2 model")
         return mv2
@@ -36,7 +34,6 @@ def __init__(self):
         pass
 
     def get_eager_model(self) -> torch.nn.Module:
-        # pyre-ignore
         mv2 = mobilenet_v2()
         return mv2
 

examples/models/mobilenet_v3/model.py

Lines changed: 0 additions & 2 deletions

@@ -19,9 +19,7 @@ def __init__(self):
 
     def get_eager_model(self) -> torch.nn.Module:
         logging.info("Loading mobilenet_v3 model")
-        # pyre-ignore
         mv3_small = models.mobilenet_v3_small(
-            # pyre-ignore[16]
             weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
         )
         logging.info("Loaded mobilenet_v3 model")

examples/models/resnet/model.py

Lines changed: 0 additions & 3 deletions

@@ -8,7 +8,6 @@
 
 import torch
 
-# pyre-ignore
 from torchvision.models import (  # @manual
     resnet18,
     ResNet18_Weights,
@@ -25,7 +24,6 @@ def __init__(self):
 
     def get_eager_model(self) -> torch.nn.Module:
         logging.info("Loading torchvision resnet18 model")
-        # pyre-ignore
         resnet18_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
         logging.info("Loaded torchvision resnet18 model")
         return resnet18_model
@@ -41,7 +39,6 @@ def __init__(self):
 
     def get_eager_model(self) -> torch.nn.Module:
         logging.info("Loading torchvision resnet50 model")
-        # pyre-ignore
         resnet50_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
         logging.info("Loaded torchvision resnet50 model")
         return resnet50_model

examples/models/torchvision_vit/model.py

Lines changed: 0 additions & 2 deletions

@@ -8,7 +8,6 @@
 
 import torch
 
-# pyre-ignore
 from torchvision.models import vit_b_16  # @manual
 
 from ..model_base import EagerModelBase
@@ -20,7 +19,6 @@ def __init__(self):
 
     def get_eager_model(self) -> torch.nn.Module:
         logging.info("Loading torchvision vit_b_16 model")
-        # pyre-ignore
         vit_b_16_model = vit_b_16(weights="IMAGENET1K_V1")
         logging.info("Loaded torchvision vit_b_16 model")
         return vit_b_16_model

exir/serde/export_serialize.py

Lines changed: 1 addition & 1 deletion

@@ -1661,7 +1661,7 @@ def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
         elif x.type == "as_sym_int":
             return PySymIntArgument(name=x.as_sym_int.as_name)
         else:
-            return PyConstantArgument(value=self.deserialize_input(x))
+            return PyConstantArgument(name="", value=self.deserialize_input(x))
 
     def deserialize_module_call_signature(
         self, module_call_signature: ModuleCallSignature
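A hedged note on the one-line change above: the added name="" is the kind of argument a required field with no default forces, and the real ConstantArgument-style spec lives in torch.export's serialization schema rather than here. An illustrative stand-in, not the actual class:

from dataclasses import dataclass
from typing import Any


@dataclass
class ConstantArgumentSketch:
    # Stand-in for the real spec: `name` has no default, so every
    # constructor call must pass it explicitly, even if only as "".
    name: str
    value: Any


print(ConstantArgumentSketch(name="", value=42))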

exir/serde/serialize.py

Lines changed: 4 additions & 4 deletions

@@ -756,7 +756,7 @@ def deserialize(
 
 def save(
     ep_save: ep.ExportedProgram,
-    f: Union[str, os.PathLike, io.BytesIO],
+    f: Union[str, os.PathLike[str], io.BytesIO],
     *,
     extra_files: Optional[Dict[str, Any]] = None,
     opset_version: Optional[Dict[str, int]] = None,
@@ -767,7 +767,7 @@ def save(
     artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version)
 
     if isinstance(f, (str, os.PathLike)):
-        f = os.fspath(f)
+        f = os.fspath(str(f))
 
     with zipfile.ZipFile(f, "w") as zipf:
         # Save every field in the SerializedArtifact to a file.
@@ -786,13 +786,13 @@ def save(
 
 
 def load(
-    f: Union[str, os.PathLike, io.BytesIO],
+    f: Union[str, os.PathLike[str], io.BytesIO],
     *,
     extra_files: Optional[Dict[str, Any]] = None,
     expected_opset_version: Optional[Dict[str, int]] = None,
 ) -> ep.ExportedProgram:
     if isinstance(f, (str, os.PathLike)):
-        f = os.fspath(f)
+        f = os.fspath(str(f))
 
     extra_files = extra_files or {}
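Both typing tweaks in this file are about Pyre's view of generics: os.PathLike is generic over str/bytes, so the parameter is now annotated as os.PathLike[str], and wrapping f in str() before os.fspath keeps the narrowed value typed as plain str. A small standalone sketch of the same signature shape (illustrative only, not the ExecuTorch save/load functions):

import io
import os
from pathlib import Path
from typing import Union


def resolve(f: Union[str, os.PathLike[str], io.BytesIO]) -> Union[str, io.BytesIO]:
    # Mirrors the narrowing done in save()/load(): path-likes become str,
    # byte streams pass through untouched.
    if isinstance(f, (str, os.PathLike)):
        return os.fspath(str(f))
    return f


print(resolve(Path("/tmp/model.pt2")))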

0 commit comments
