Skip to content

Commit 80bf883

Browse files
swolchokpytorchmergebot
authored andcommitted
Replace manual cache in _python_dispatch.get_alias_info with functools.cache (pytorch#161286)
In addition to being more code, the manual cache was doing an extra dictionary lookup on each cache hit. Pull Request resolved: pytorch#161286 Approved by: https://github.com/wconstab
1 parent 9de9d25 commit 80bf883

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

torch/utils/_python_dispatch.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# mypy: allow-untyped-defs
22
import contextlib
3+
import functools
34
import warnings
45
from collections import deque
56
from collections.abc import Sequence
67
from dataclasses import dataclass
7-
from typing import Any, Optional, overload, Protocol, Union
8+
from typing import Optional, overload, Protocol, Union
89
from typing_extensions import TypeIs
910

1011
import torch
@@ -527,15 +528,10 @@ class SchemaInfo:
527528
outs: list[AliasInfo]
528529

529530

530-
# Can't import torch._ops.OpOverload due to circular reference
531-
parsed_schema_map: dict[Any, SchemaInfo] = {}
532-
533-
534531
# Given an OpOverload, returns schema information on it.
535532
# This is cached for efficiency, since it can involve running torchgen
533+
@functools.cache
536534
def get_alias_info(func) -> SchemaInfo:
537-
if func in parsed_schema_map:
538-
return parsed_schema_map[func]
539535
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
540536
# properly for some ops that output tensorlists)
541537
if func.namespace == "aten":
@@ -598,7 +594,6 @@ def get_alias_info(func) -> SchemaInfo:
598594
for a in func._schema.returns
599595
]
600596
schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
601-
parsed_schema_map[func] = schema_info
602597
return schema_info
603598

604599

0 commit comments

Comments
 (0)