|  | 
| 1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. | 
| 2 | 2 | # All rights reserved. | 
|  | 3 | +# Copyright 2025 Arm Limited and/or its affiliates. | 
| 3 | 4 | # | 
| 4 | 5 | # This source code is licensed under the BSD-style license found in the | 
| 5 | 6 | # LICENSE file in the root directory of this source tree | 
| 6 | 7 | 
 | 
|  | 8 | +# | 
|  | 9 | +# This source code is licensed under the BSD-style license found in the | 
|  | 10 | +# LICENSE file in the root directory of this source tree. | 
|  | 11 | + | 
| 7 | 12 | import logging | 
| 8 |  | -from typing import Any, Dict, List, Optional, Union | 
|  | 13 | +from typing import Any, Dict, List, Optional, Sequence, Union | 
| 9 | 14 | 
 | 
| 10 | 15 | import numpy as np | 
| 11 | 16 | 
 | 
| 12 | 17 | import torch | 
|  | 18 | +import torch.fx as fx | 
| 13 | 19 | 
 | 
| 14 | 20 | from executorch.exir import EdgeProgramManager, ExportedProgram | 
| 15 | 21 | from executorch.exir.dialects._ops import ops as exir_ops | 
| @@ -316,3 +322,93 @@ def call(self, graph_module: torch.fx.GraphModule): | 
| 316 | 322 |             self.edge_manager_update_quant_config_method(i, self.dequant_args[i]) | 
| 317 | 323 | 
 | 
| 318 | 324 |         return PassResult(graph_module, True) | 
|  | 325 | + | 
|  | 326 | + | 
|  | 327 | +def extract_io_quant_params( | 
|  | 328 | +    edge_prog: EdgeProgramManager, | 
|  | 329 | +    *, | 
|  | 330 | +    input_idxs: Sequence[int] = (0,), | 
|  | 331 | +    output_idxs: Sequence[int] = (0,), | 
|  | 332 | +) -> Dict[str, Dict[str, Dict[str, Any]]]: | 
|  | 333 | +    """ | 
|  | 334 | +    Returns quantization parameters such as scale/zero_point: | 
|  | 335 | +      { | 
|  | 336 | +        "inputs": { | 
|  | 337 | +          <placeholder_name>: {"scale": float, "zero_point": int} | 
|  | 338 | +        }, | 
|  | 339 | +        "outputs": { | 
|  | 340 | +          <node_name>: {"scale": float, "zero_point": int} | 
|  | 341 | +        } | 
|  | 342 | +      } | 
|  | 343 | +
 | 
|  | 344 | +    Note that this function will strip out the IO quantize/dequantize ops as | 
|  | 345 | +    it records their parameters, so if you need to preserve the original graph | 
|  | 346 | +    you need to make a copy with copy.deepcopy before. | 
|  | 347 | +
 | 
|  | 348 | +    Note that `to_edge_transform_and_lower` should be called before. | 
|  | 349 | +    """ | 
|  | 350 | +    # Use IO passes | 
|  | 351 | +    passes = [] | 
|  | 352 | +    for idx in input_idxs: | 
|  | 353 | +        passes.append(QuantizeInputs(edge_prog, [idx])) | 
|  | 354 | +    for idx in output_idxs: | 
|  | 355 | +        passes.append(QuantizeOutputs(edge_prog, [idx])) | 
|  | 356 | + | 
|  | 357 | +    # Apply them | 
|  | 358 | +    edge_prog = edge_prog.transform(passes) | 
|  | 359 | + | 
|  | 360 | +    cfg = getattr(edge_prog, "_config_methods", {}) or {} | 
|  | 361 | + | 
|  | 362 | +    # We need GraphModule to find node names | 
|  | 363 | +    gm = edge_prog.exported_program().graph_module | 
|  | 364 | + | 
|  | 365 | +    input_names = _gather_io_names(gm, side="input") | 
|  | 366 | +    output_names = _gather_io_names(gm, side="output") | 
|  | 367 | + | 
|  | 368 | +    # Build the result dict | 
|  | 369 | +    result = {"inputs": {}, "outputs": {}} | 
|  | 370 | +    for key, val in cfg.items(): | 
|  | 371 | +        if key.startswith("input"): | 
|  | 372 | +            prefix, section, names = "input", "inputs", input_names | 
|  | 373 | +        elif key.startswith("output"): | 
|  | 374 | +            prefix, section, names = "output", "outputs", output_names | 
|  | 375 | +        else: | 
|  | 376 | +            continue | 
|  | 377 | + | 
|  | 378 | +        idx_str, param = key[len(prefix) :].split("_", 1) | 
|  | 379 | +        idx = int(idx_str) | 
|  | 380 | +        name = names[idx] | 
|  | 381 | +        # We need to map 'zp' to 'zero_point' | 
|  | 382 | +        out_param = "zero_point" if param in ("zp", "zero_point") else param | 
|  | 383 | +        result[section].setdefault(name, {})[out_param] = val | 
|  | 384 | + | 
|  | 385 | +    return result | 
|  | 386 | + | 
|  | 387 | + | 
|  | 388 | +def _gather_io_names(gm: fx.GraphModule, side: str): | 
|  | 389 | +    """ | 
|  | 390 | +    For 'input', returns placeholder names in graph order. | 
|  | 391 | +    For 'output', returns names of output nodes. | 
|  | 392 | +    """ | 
|  | 393 | +    if side == "input": | 
|  | 394 | +        return [n.name for n in gm.graph.nodes if n.op == "placeholder"] | 
|  | 395 | + | 
|  | 396 | +    if side == "output": | 
|  | 397 | + | 
|  | 398 | +        def _flatten(args): | 
|  | 399 | +            out = [] | 
|  | 400 | + | 
|  | 401 | +            def rec(x): | 
|  | 402 | +                if isinstance(x, (tuple, list)): | 
|  | 403 | +                    for y in x: | 
|  | 404 | +                        rec(y) | 
|  | 405 | +                elif isinstance(x, fx.Node): | 
|  | 406 | +                    out.append(x) | 
|  | 407 | + | 
|  | 408 | +            rec(args) | 
|  | 409 | +            return out | 
|  | 410 | + | 
|  | 411 | +        output_node = next(n for n in gm.graph.nodes if n.op == "output") | 
|  | 412 | +        return [n.name for n in _flatten(output_node.args)] | 
|  | 413 | + | 
|  | 414 | +    raise ValueError(f"Unknown side: {side}") | 
0 commit comments