|
1 | 1 | # mypy: allow-untyped-defs
|
2 | 2 | import contextlib
|
| 3 | +import functools |
3 | 4 | import warnings
|
4 | 5 | from collections import deque
|
5 | 6 | from collections.abc import Sequence
|
6 | 7 | from dataclasses import dataclass
|
7 |
| -from typing import Any, Optional, overload, Protocol, Union |
| 8 | +from typing import Optional, overload, Protocol, Union |
8 | 9 | from typing_extensions import TypeIs
|
9 | 10 |
|
10 | 11 | import torch
|
@@ -527,15 +528,10 @@ class SchemaInfo:
|
527 | 528 | outs: list[AliasInfo]
|
528 | 529 |
|
529 | 530 |
|
530 |
| -# Can't import torch._ops.OpOverload due to circular reference |
531 |
| -parsed_schema_map: dict[Any, SchemaInfo] = {} |
532 |
| - |
533 |
| - |
534 | 531 | # Given an OpOverload, returns schema information on it.
|
535 | 532 | # This is cached for efficiency, since it can involve running torchgen
|
| 533 | +@functools.cache |
536 | 534 | def get_alias_info(func) -> SchemaInfo:
|
537 |
| - if func in parsed_schema_map: |
538 |
| - return parsed_schema_map[func] |
539 | 535 | # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
|
540 | 536 | # properly for some ops that output tensorlists)
|
541 | 537 | if func.namespace == "aten":
|
@@ -598,7 +594,6 @@ def get_alias_info(func) -> SchemaInfo:
|
598 | 594 | for a in func._schema.returns
|
599 | 595 | ]
|
600 | 596 | schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
|
601 |
| - parsed_schema_map[func] = schema_info |
602 | 597 | return schema_info
|
603 | 598 |
|
604 | 599 |
|
|
0 commit comments