Skip to content

Commit ef21739

Browse files
Arm backend: Reduce complexity of get_model_and_inputs_from_name (pytorch#15247)
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 44511ba commit ef21739

File tree

1 file changed

+150
-73
lines changed

1 file changed

+150
-73
lines changed

examples/arm/aot_arm_compiler.py

Lines changed: 150 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -71,76 +71,151 @@
7171
logging.basicConfig(level=logging.WARNING, format=FORMAT)
7272

7373

74-
def get_model_and_inputs_from_name(
75-
model_name: str, model_input: str | None
76-
) -> Tuple[torch.nn.Module, Any]:
77-
"""Given the name of an example pytorch model, return it and example inputs.
74+
def _load_example_inputs(model_input: str | None) -> Any:
75+
"""Load example inputs from a `.pt` file when a path is provided."""
76+
if model_input is None:
77+
return None
78+
79+
logging.info(f"Load model input from {model_input}")
80+
81+
if model_input.endswith(".pt"):
82+
return torch.load(model_input, weights_only=False)
83+
84+
raise RuntimeError(
85+
f"Model input data '{model_input}' is not a valid name. Use --model_input "
86+
"<FILE>.pt e.g. saved with torch.save()"
87+
)
88+
89+
90+
def _load_internal_model(
91+
model_name: str, example_inputs: Any
92+
) -> Optional[Tuple[torch.nn.Module, Any]]:
93+
"""Load a bundled example model from the internal `MODELS` mapping."""
94+
if model_name not in MODELS:
95+
return None
96+
97+
logging.info(f"Internal model {model_name}")
98+
99+
model = MODELS[model_name]()
100+
inputs = (
101+
example_inputs
102+
if example_inputs is not None
103+
else MODELS[model_name].example_input
104+
)
105+
106+
return model, inputs
107+
108+
109+
def _load_registered_model(
110+
model_name: str, example_inputs: Any
111+
) -> Optional[Tuple[torch.nn.Module, Any]]:
112+
"""Load a registered example model from `examples.models`."""
113+
if model_name not in MODEL_NAME_TO_MODEL:
114+
return None
115+
116+
logging.warning(
117+
"Using a model from examples/models not all of these are currently supported"
118+
)
119+
logging.info(
120+
f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models"
121+
)
122+
123+
model, tmp_example_inputs, _, _ = EagerModelFactory.create_model(
124+
*MODEL_NAME_TO_MODEL[model_name]
125+
)
126+
inputs = example_inputs if example_inputs is not None else tmp_example_inputs
127+
128+
return model, inputs
129+
130+
131+
def _load_python_module_model(
132+
model_name: str, example_inputs: Any
133+
) -> Optional[Tuple[torch.nn.Module, Any]]:
134+
"""Load a model and inputs from a Python source file.
135+
136+
The file must define `ModelUnderTest` and `ModelInputs` attributes.
78137
79-
Raises RuntimeError if there is no example model corresponding to the given name.
80138
"""
81-
example_inputs = None
82-
if model_input is not None:
83-
logging.info(f"Load model input from {model_input}")
84-
if model_input.endswith(".pt"):
85-
example_inputs = torch.load(model_input, weights_only=False)
86-
else:
87-
raise RuntimeError(
88-
f"Model input data '{model_input}' is not a valid name. Use --model_input <FILE>.pt e.g. saved with torch.save()"
89-
)
139+
if not model_name.endswith(".py"):
140+
return None
90141

91-
# Case 1: Model is defined in this file
92-
if model_name in models.keys():
93-
logging.info(f"Internal model {model_name}")
94-
model = models[model_name]()
95-
if example_inputs is None:
96-
example_inputs = models[model_name].example_input
97-
# Case 2: Model is defined in examples/models/
98-
elif model_name in MODEL_NAME_TO_MODEL.keys():
99-
logging.warning(
100-
"Using a model from examples/models not all of these are currently supported"
101-
)
102-
logging.info(
103-
f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models"
104-
)
142+
logging.info(
143+
f"Load model file {model_name} "
144+
"Variable ModelUnderTest=<Model> ModelInputs=<ModelInput>"
145+
)
105146

106-
model, tmp_example_inputs, _, _ = EagerModelFactory.create_model(
107-
*MODEL_NAME_TO_MODEL[model_name]
108-
)
109-
if example_inputs is None:
110-
example_inputs = tmp_example_inputs
111-
# Case 3: Model is in an external python file loaded as a module.
112-
# ModelUnderTest should be a torch.nn.module instance
113-
# ModelInputs should be a tuple of inputs to the forward function
114-
elif model_name.endswith(".py"):
115-
logging.info(
116-
f"Load model file {model_name} Variable ModelUnderTest=<Model> ModelInputs=<ModelInput>"
117-
)
118-
import importlib.util
119-
120-
# load model's module and add it
121-
spec = importlib.util.spec_from_file_location("tmp_model", model_name)
122-
module = importlib.util.module_from_spec(spec)
123-
spec.loader.exec_module(module)
124-
model = module.ModelUnderTest
125-
if example_inputs is None:
126-
example_inputs = module.ModelInputs
127-
# Case 4: Model is in an saved model file torch.save(model)
128-
elif model_name.endswith(".pth") or model_name.endswith(".pt"):
129-
logging.info(f"Load model file {model_name}")
130-
model = torch.load(model_name, weights_only=False)
131-
if example_inputs is None:
132-
raise RuntimeError(
133-
f"Model '{model_name}' requires input data specify --model_input <FILE>.pt"
134-
)
135-
else:
147+
import importlib.util
148+
149+
spec = importlib.util.spec_from_file_location("tmp_model", model_name)
150+
if spec is None or spec.loader is None:
151+
raise RuntimeError(f"Unable to load model file {model_name}")
152+
module = importlib.util.module_from_spec(spec)
153+
spec.loader.exec_module(module)
154+
model = module.ModelUnderTest
155+
inputs = example_inputs if example_inputs is not None else module.ModelInputs
156+
157+
return model, inputs
158+
159+
160+
def _load_serialized_model(
161+
model_name: str, example_inputs: Any
162+
) -> Optional[Tuple[torch.nn.Module, Any]]:
163+
"""Load a serialized Torch model saved via `torch.save`."""
164+
if not model_name.endswith((".pth", ".pt")):
165+
return None
166+
167+
logging.info(f"Load model file {model_name}")
168+
169+
model = torch.load(model_name, weights_only=False)
170+
if example_inputs is None:
136171
raise RuntimeError(
137-
f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
172+
f"Model '{model_name}' requires input data specify --model_input <FILE>.pt"
138173
)
139-
logging.debug(f"Loaded model: {model}")
140-
logging.debug(f"Loaded input: {example_inputs}")
174+
141175
return model, example_inputs
142176

143177

178+
def get_model_and_inputs_from_name(
    model_name: str, model_input: str | None
) -> Tuple[torch.nn.Module, Any]:
    """Turn a model identifier into a ``(model, example_inputs)`` pair.

    Each loader understands one kind of identifier — bundled model name,
    ``examples.models`` registry key, ``.py`` module path, or serialized
    ``.pth``/``.pt`` file — and returns None for names it does not handle;
    the first loader that succeeds wins.

    Args:
        model_name: Identifier for the model. It can be a key in
            `MODEL_NAME_TO_MODEL`, a Python module path, or a serialized
            model file path.
        model_input: Optional path to a `.pt` file containing example inputs.

    Returns:
        Tuple of `(model, example_inputs)` ready for compilation.

    Raises:
        RuntimeError: If the model cannot be resolved or required inputs are
            missing.
    """
    example_inputs = _load_example_inputs(model_input)

    for loader in (
        _load_internal_model,
        _load_registered_model,
        _load_python_module_model,
        _load_serialized_model,
    ):
        loaded = loader(model_name, example_inputs)
        if loaded is None:
            continue
        model, example_inputs = loaded
        logging.debug(f"Loaded model: {model}")
        logging.debug(f"Loaded input: {example_inputs}")
        return model, example_inputs

    raise RuntimeError(
        f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
    )
217+
218+
144219
def quantize(
145220
model: GraphModule,
146221
model_name: str,
@@ -150,7 +225,9 @@ def quantize(
150225
evaluator_config: Dict[str, Any] | None,
151226
) -> GraphModule:
152227
"""This is the official recommended flow for quantization in pytorch 2.0
153-
export"""
228+
export.
229+
230+
"""
154231
logging.info("Quantizing Model...")
155232
logging.debug(f"Original model: {model}")
156233

@@ -238,7 +315,7 @@ def forward(self, x):
238315
can_delegate = True
239316

240317

241-
models = {
318+
MODELS = {
242319
"qadd": QuantAddTest,
243320
"qadd2": QuantAddTest2,
244321
"qops": QuantOpTest,
@@ -247,7 +324,7 @@ def forward(self, x):
247324
"qlinear": QuantLinearTest,
248325
}
249326

250-
calibration_data = {
327+
CALIBRATION_DATA = {
251328
"qadd": (torch.randn(32, 2, 1),),
252329
"qadd2": (
253330
torch.randn(32, 2, 1),
@@ -261,7 +338,7 @@ def forward(self, x):
261338
),
262339
}
263340

264-
targets = [
341+
TARGETS = [
265342
"ethos-u55-32",
266343
"ethos-u55-64",
267344
"ethos-u55-128",
@@ -289,10 +366,10 @@ def get_calibration_data(
289366
if evaluator_data is not None:
290367
return evaluator_data
291368

292-
# If the model is in the calibration_data dictionary, get the data from there
369+
# If the model is in the CALIBRATION_DATA dictionary, get the data from there
293370
# This is used for the simple model examples provided
294-
if model_name in calibration_data:
295-
return calibration_data[model_name]
371+
if model_name in CALIBRATION_DATA:
372+
return CALIBRATION_DATA[model_name]
296373

297374
# As a last resort, fallback to the scripts previous behavior and return the example inputs
298375
return example_inputs
@@ -365,7 +442,7 @@ def get_args():
365442
"-m",
366443
"--model_name",
367444
required=True,
368-
help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(models.keys())+list(MODEL_NAME_TO_MODEL.keys()))}",
445+
help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(MODELS.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}",
369446
)
370447
parser.add_argument(
371448
"--model_input",
@@ -401,8 +478,8 @@ def get_args():
401478
action="store",
402479
required=False,
403480
default="ethos-u55-128",
404-
choices=targets,
405-
help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}",
481+
choices=TARGETS,
482+
help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {TARGETS}",
406483
)
407484
parser.add_argument(
408485
"-e",
@@ -506,9 +583,9 @@ def get_args():
506583
torch.ops.load_library(args.so_library)
507584

508585
if (
509-
args.model_name in models.keys()
586+
args.model_name in MODELS.keys()
510587
and args.delegate is True
511-
and models[args.model_name].can_delegate is False
588+
and MODELS[args.model_name].can_delegate is False
512589
):
513590
raise RuntimeError(f"Model {args.model_name} cannot be delegated.")
514591

0 commit comments

Comments
 (0)