 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
+# pyre-strict
 
 import math
 import unittest
-from typing import cast, Optional
+from typing import cast, List, Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -224,11 +224,11 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
     # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
     def run_memory_planning(
         self,
-        original,
-        opt_level=2,
-        mem_algo=1,  # greedy_by_size_for_offset_calculation_with_hierarchy
-        alloc_graph_input=True,
-        alloc_graph_output=True,
+        original: GraphModule,
+        opt_level: int = 2,
+        mem_algo: int = 1,  # greedy_by_size_for_offset_calculation_with_hierarchy
+        alloc_graph_input: bool = True,
+        alloc_graph_output: bool = True,
         memory_config: Optional[MemoryConfig] = None,
     ) -> GraphModule:
         if memory_config is None:
@@ -242,6 +242,7 @@ def run_memory_planning(
             alloc_graph_output=alloc_graph_output,
         )(graph_module).graph_module
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             [
@@ -259,7 +260,11 @@ def run_memory_planning(
         ]
     )
     def test_optimize_cat_on_placeholders(
-        self, x_shape, y_shape, concat_dim, alloc_graph_input
+        self,
+        x_shape: List[int],
+        y_shape: List[int],
+        concat_dim: int,
+        alloc_graph_input: bool,
     ) -> None:
         concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]]
         builder = GraphBuilder()
@@ -294,7 +299,12 @@ def test_optimize_cat_on_placeholders(
294299 # "add_add_cat_model" : cat(x + 123, y + 456)
295300 # "add_add_cat_add_model": cat(x + 123, y + 456) + 789
296301 def get_graph_module (
297- self , model_name , x_shape , y_shape , concated_shape , concat_dim
302+ self ,
303+ model_name : str ,
304+ x_shape : List [int ],
305+ y_shape : List [int ],
306+ concated_shape : List [int ],
307+ concat_dim : int ,
298308 ) -> GraphModule :
299309 builder = GraphBuilder ()
300310 x = builder .placeholder ("x" , torch .ones (* x_shape , dtype = torch .float32 ))
@@ -346,6 +356,7 @@ def get_graph_module(
 
         raise ValueError(f"Unknown model name {model_name}")
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -366,7 +377,12 @@ def get_graph_module(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_optimized(
-        self, _, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim
@@ -379,6 +395,7 @@ def test_cat_optimized(
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -392,7 +409,12 @@ def test_cat_optimized(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_not_optimized(
-        self, _, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim
@@ -404,6 +426,7 @@ def test_cat_not_optimized(
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -426,7 +449,13 @@ def test_cat_not_optimized(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_not_graph_output(
-        self, _, x_shape, y_shape, concated_shape, concat_dim, expected_cat_nodes
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
+        expected_cat_nodes: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_add_model", x_shape, y_shape, concated_shape, concat_dim
@@ -493,13 +522,14 @@ def test_optimize_cat_with_slice(self) -> None:
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1)
         self.verify_nop_memory_alloc(graph_module)
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (True,),  # alloc_graph_input
             (False,),  # alloc_graph_input
         ],
     )
-    def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input) -> None:
+    def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input: bool) -> None:
         x_shape = [5, 6]
         y_shape = [3, 6]
         concated_shape = [8, 6]
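
Aside from the patch itself: a minimal, self-contained sketch of the shape bookkeeping these cat tests rely on. It only restates the arithmetic visible above (the `concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]]` formula and the hard-coded `[5, 6]`/`[3, 6]` -> `[8, 6]` case); it is not part of the diff and uses eager `torch.cat` rather than the Cadence graph-level ops the tests exercise.

import torch

# Shapes taken from test_optimize_cat_with_slice_infeasible above.
x_shape, y_shape, concat_dim = [5, 6], [3, 6], 0

# Same formula the tests use to derive the concatenated shape (concat_dim == 0 here).
concated_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]]
assert concated_shape == [8, 6]

# Sanity check with eager torch.cat.
out = torch.cat((torch.ones(*x_shape), torch.ones(*y_shape)), dim=concat_dim)
assert list(out.shape) == concated_shape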