"""Load a serialized Torch model saved via `torch.save`."""
164
+
ifnotmodel_name.endswith((".pth", ".pt")):
165
+
returnNone
166
+
167
+
logging.info(f"Load model file {model_name}")
168
+
169
+
model=torch.load(model_name, weights_only=False)
170
+
ifexample_inputsisNone:
136
171
raiseRuntimeError(
137
-
f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
172
+
f"Model '{model_name}' requires input data specify --model_input <FILE>.pt"
138
173
)
139
-
logging.debug(f"Loaded model: {model}")
140
-
logging.debug(f"Loaded input: {example_inputs}")
174
+
141
175
returnmodel, example_inputs
142
176
143
177
178
+
defget_model_and_inputs_from_name(
179
+
model_name: str, model_input: str|None
180
+
) ->Tuple[torch.nn.Module, Any]:
181
+
"""Resolve a model name into a model instance and example inputs.
182
+
183
+
Args:
184
+
model_name: Identifier for the model. It can be a key in
185
+
`MODEL_NAME_TO_MODEL`, a Python module path, or a serialized
186
+
model file path.
187
+
model_input: Optional path to a `.pt` file containing example inputs.
188
+
189
+
Returns:
190
+
Tuple of `(model, example_inputs)` ready for compilation.
191
+
192
+
Raises:
193
+
RuntimeError: If the model cannot be resolved or required inputs are
194
+
missing.
195
+
196
+
"""
197
+
example_inputs=_load_example_inputs(model_input)
198
+
199
+
loaders= (
200
+
_load_internal_model,
201
+
_load_registered_model,
202
+
_load_python_module_model,
203
+
_load_serialized_model,
204
+
)
205
+
206
+
forloaderinloaders:
207
+
result=loader(model_name, example_inputs)
208
+
ifresultisnotNone:
209
+
model, example_inputs=result
210
+
logging.debug(f"Loaded model: {model}")
211
+
logging.debug(f"Loaded input: {example_inputs}")
212
+
returnmodel, example_inputs
213
+
214
+
raiseRuntimeError(
215
+
f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
216
+
)
217
+
218
+
144
219
defquantize(
145
220
model: GraphModule,
146
221
model_name: str,
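Each `_load_*` helper above follows the same contract: it takes `(model_name, example_inputs)`, returns `None` when the name is not its kind of model, and otherwise returns a ready `(model, example_inputs)` pair; `get_model_and_inputs_from_name` simply tries the loaders in order. A minimal sketch of how the chain could be extended follows; the TorchScript loader and the `.torchscript` suffix are purely hypothetical and not part of this change.

# Hypothetical extra loader, shown only to illustrate the loader contract.
from typing import Any, Optional, Tuple

import torch


def _load_torchscript_model(
    model_name: str, example_inputs: Any
) -> Optional[Tuple[torch.nn.Module, Any]]:
    if not model_name.endswith(".torchscript"):
        return None  # not ours; let the next loader in the tuple try
    model = torch.jit.load(model_name)
    if example_inputs is None:
        raise RuntimeError(f"Model '{model_name}' requires --model_input <FILE>.pt")
    return model, example_inputs

# To use it, the function would be appended to the `loaders` tuple in
# get_model_and_inputs_from_name.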
@@ -150,7 +225,9 @@ def quantize(
     evaluator_config: Dict[str, Any] | None,
 ) -> GraphModule:
     """This is the official recommended flow for quantization in pytorch 2.0
-    export"""
+    export.
+
+    """
     logging.info("Quantizing Model...")
     logging.debug(f"Original model: {model}")
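For context, `quantize()` follows the PT2E (PyTorch 2 export) quantization flow referenced in its docstring. The sketch below shows the generic shape of that flow only; `XNNPACKQuantizer` stands in for whichever Arm/Ethos-U quantizer this script actually configures, and `calibration_batches` is assumed to be an iterable of input tuples such as the ones returned by `get_calibration_data`.

# Minimal PT2E sketch, not this script's exact implementation.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


def quantize_sketch(graph_module, calibration_batches):
    # graph_module is assumed to already be an exported GraphModule,
    # matching the `model: GraphModule` parameter above.
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(graph_module, quantizer)  # insert observers
    for batch in calibration_batches:                 # calibrate on real data
        prepared(*batch)
    return convert_pt2e(prepared)                     # fold observers into q/dq ops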
@@ -238,7 +315,7 @@ def forward(self, x):
     can_delegate = True


-models = {
+MODELS = {
     "qadd": QuantAddTest,
     "qadd2": QuantAddTest2,
     "qops": QuantOpTest,
@@ -247,7 +324,7 @@ def forward(self, x):
     "qlinear": QuantLinearTest,
 }

-calibration_data = {
+CALIBRATION_DATA = {
     "qadd": (torch.randn(32, 2, 1),),
     "qadd2": (
         torch.randn(32, 2, 1),
@@ -261,7 +338,7 @@ def forward(self, x):
     ),
 }

-targets = [
+TARGETS = [
     "ethos-u55-32",
     "ethos-u55-64",
     "ethos-u55-128",
@@ -289,10 +366,10 @@ def get_calibration_data(
     if evaluator_data is not None:
         return evaluator_data

-    # If the model is in the calibration_data dictionary, get the data from there
+    # If the model is in the CALIBRATION_DATA dictionary, get the data from there
     # This is used for the simple model examples provided
-    if model_name in calibration_data:
-        return calibration_data[model_name]
+    if model_name in CALIBRATION_DATA:
+        return CALIBRATION_DATA[model_name]

     # As a last resort, fallback to the scripts previous behavior and return the example inputs
     return example_inputs
@@ -365,7 +442,7 @@ def get_args():
         "-m",
         "--model_name",
         required=True,
-        help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(models.keys())+list(MODEL_NAME_TO_MODEL.keys()))}",
+        help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(MODELS.keys())+list(MODEL_NAME_TO_MODEL.keys()))}",
     )
     parser.add_argument(
         "--model_input",
@@ -401,8 +478,8 @@ def get_args():
         action="store",
         required=False,
         default="ethos-u55-128",
-        choices=targets,
-        help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}",
+        choices=TARGETS,
+        help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {TARGETS}",
     )
     parser.add_argument(
         "-e",
@@ -506,9 +583,9 @@ def get_args():
         torch.ops.load_library(args.so_library)

     if (
-        args.model_name in models.keys()
+        args.model_name in MODELS.keys()
         and args.delegate is True
-        and models[args.model_name].can_delegate is False
+        and MODELS[args.model_name].can_delegate is False
     ):
         raise RuntimeError(f"Model {args.model_name} cannot be delegated.")