1111import typing
1212from collections import defaultdict
1313from dataclasses import dataclass
14- from typing import cast , DefaultDict , Iterable , Optional , Sequence
14+ from typing import Callable , cast , DefaultDict , Iterable , Optional , Sequence , TypeAlias
1515
1616import torch
1717import torch .fx
@@ -573,23 +573,34 @@ def compute_slice_and_select_loc_constraints(
573573 graph_module .recompile ()
574574
575575
576+ ConstraintsGenPass : TypeAlias = Callable [
577+ [MemConstraints ],
578+ Callable [[torch .fx .GraphModule ], Optional [PassResult ]],
579+ ]
580+
581+
576582# The class to generate all the constraints that will be passed on to the memory
577583# planning algorithm.
578584class GenerateMemConstraints :
579585 def __init__ (
580586 self ,
581587 mem_constraints : MemConstraints ,
582- additional_constraint_gen_passes : list | None = None ,
588+ additional_constraint_gen_passes : Sequence [ ConstraintsGenPass ] | None = None ,
583589 ) -> None :
584- self .mem_constraints = mem_constraints
585- self .additional_constraint_gen_passes = additional_constraint_gen_passes or []
590+ self .mem_constraints : MemConstraints = mem_constraints
591+ self .additional_constraint_gen_passes : Sequence [ConstraintsGenPass ] = (
592+ additional_constraint_gen_passes or []
593+ )
586594
587595 def __call__ (self , graph_module : torch .fx .GraphModule ) -> PassResult :
588- constraint_gen_passes : list = [
589- GenerateMemoryViewConstraints ,
590- GenerateSliceAndSelectNopConstraints ,
591- GenerateCatNopConstraints ,
592- ] + self .additional_constraint_gen_passes
596+ constraint_gen_passes : Sequence [ConstraintsGenPass ] = cast (
597+ list [ConstraintsGenPass ],
598+ [
599+ GenerateMemoryViewConstraints ,
600+ GenerateSliceAndSelectNopConstraints ,
601+ GenerateCatNopConstraints ,
602+ ],
603+ ) + list (self .additional_constraint_gen_passes )
593604 # Create a filter using the opt level in mem_constraints, and filter
594605 # the relevant passes.
595606 pass_filter = create_cadence_pass_filter (self .mem_constraints .opt_level )
@@ -602,6 +613,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
602613 typing .Callable [[torch .fx .GraphModule ], Optional [PassResult ]],
603614 ]
604615 ],
616+ # pyre-ignore[6]: Incompatible parameter type.
605617 list (filter (pass_filter , constraint_gen_passes )),
606618 )
607619 ]
0 commit comments