@@ -725,17 +725,17 @@ def count_nodes(graph_module, target):
725725 )
726726
727727 def test_edge_dialect_non_core_aten_ops (self ):
728- class LinalgNorm (torch .nn .Module ):
728+ class LinalgRank (torch .nn .Module ):
729729 def __init__ (self ):
730730 super ().__init__ ()
731731
732732 def forward (self , x : torch .Tensor ) -> torch .Tensor :
733- return torch .linalg .norm (x )
733+ return torch .linalg .matrix_rank (x )
734734
735735 from torch ._export .verifier import SpecViolationError
736736
737- input = torch .arange ( 9 , dtype = torch .float ) - 4
738- ep = torch .export .export (LinalgNorm (), (input ,), strict = True )
737+ input = torch .ones (( 9 , 9 , 9 ), dtype = torch .float )
738+ ep = torch .export .export (LinalgRank (), (input ,), strict = True )
739739
740740 # aten::linalg_norm is not a core op, so it should error out
741741 with self .assertRaises (SpecViolationError ):
@@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
748748 ep ,
749749 compile_config = EdgeCompileConfig (
750750 _check_ir_validity = True ,
751- _core_aten_ops_exception_list = [
752- torch .ops .aten .linalg_vector_norm .default
753- ],
751+ _core_aten_ops_exception_list = [torch .ops .aten ._linalg_svd .default ],
754752 ),
755753 )
756754 except SpecViolationError :
0 commit comments