@@ -127,85 +127,44 @@ def setUpClass(cls) -> None:
127127 register_additional_test_aten_ops ()
128128
129129 def test_remove_mixed_type_operators (self ) -> None :
130+ def count_nodes_with_target_asserting_arguments_have_dtype (
131+ new_graph_module , target , arg_dtype
132+ ):
133+ count = 0
134+ for node in new_graph_module .graph .nodes :
135+ if node .op == "call_function" and node .target == target :
136+ count += 1
137+ for arg in node .args :
138+ self .assertEqual (arg .meta ["val" ].dtype , arg_dtype )
139+ return count
140+
130141 class Add (torch .nn .Module ):
131142 def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
132143 return (x + y ) + x
133144
134- add = Add ()
135-
136- int_tensor = torch .tensor ([[1 , 2 , 3 ]])
137- float_tensor = torch .tensor ([[1.0 , 2.0 , 3.0 ]])
138- edge_prog = to_edge (export (add , (int_tensor , float_tensor ), strict = True ))
139-
140- new_prog = edge_prog .transform ([RemoveMixedTypeOperators ()])
141- new_graph_module = new_prog .exported_program ().graph_module
142- self .assertIsNotNone (new_graph_module )
143-
144- add_count = 0
145-
146- for node in new_graph_module .graph .nodes :
147- if (
148- node .op == "call_function"
149- and node .target == exir_ops .edge .aten .add .Tensor
150- ):
151- add_count += 1
152- node_args = node .args
153- for arg in node_args :
154- self .assertEqual (arg .meta ["val" ].dtype , torch .float )
155-
156- self .assertEqual (add_count , 2 )
157-
158- double_tensor = torch .tensor ([[1.0 , 2.0 , 3.0 ]])
159- double_tensor = double_tensor .to (torch .double )
160-
161- double_prog = to_edge (export (add , (int_tensor , double_tensor ), strict = True ))
162-
163- double_prog .transform ([RemoveMixedTypeOperators ()])
164- new_graph_module_double = double_prog .exported_program ().graph_module
165- self .assertIsNotNone (new_graph_module_double )
166-
167- add_count_double = 0
168-
169- for node in new_graph_module_double .graph .nodes :
170- if (
171- node .op == "call_function"
172- and node .target == exir_ops .edge .aten .add .Tensor
173- ):
174- add_count_double += 1
175- node_args = node .args
176- for arg in node_args :
177- self .assertEqual (arg .meta ["val" ].dtype , torch .double )
178-
179- self .assertEqual (add_count_double , 2 )
180-
181145 class Mult (torch .nn .Module ):
182146 def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
183147 return x * y
184148
185- mult = Mult ()
186-
187- float_tensor_vert = float_tensor . T
188- mult_prog = to_edge ( export ( mult , ( int_tensor , float_tensor_vert ), strict = True ))
189-
190- # graph_module_mult.graph.print_tabular( )
191-
192- mult_prog = mult_prog . transform ([ RemoveMixedTypeOperators ()])
193- new_graph_module_mult = mult_prog . exported_program (). graph_module
194- self . assertIsNotNone ( new_graph_module_mult )
149+ for module , op , expected_count in (
150+ ( Add , exir_ops . edge . aten . add . Tensor , 2 ),
151+ ( Mult , exir_ops . edge . aten . mul . Tensor , 1 ),
152+ ):
153+ for second_arg_dtype in ( torch . int64 , torch . float , torch . double ):
154+ int_tensor = torch . tensor ([[ 1 , 2 , 3 ]], dtype = torch . int64 )
155+ float_tensor = torch . tensor ([[ 1.0 , 2.0 , 3.0 ]], dtype = second_arg_dtype )
156+ edge_prog = to_edge (
157+ export ( module (), ( int_tensor , float_tensor ), strict = True )
158+ )
195159
196- mult_count = 0
160+ new_prog = edge_prog .transform ([RemoveMixedTypeOperators ()])
161+ new_graph_module = new_prog .exported_program ().graph_module
162+ self .assertIsNotNone (new_graph_module )
197163
198- for node in new_graph_module_mult .graph .nodes :
199- if (
200- node .op == "call_function"
201- and node .target == exir_ops .edge .aten .mul .Tensor
202- ):
203- mult_count += 1
204- node_args = node .args
205- for arg in node_args :
206- self .assertEqual (arg .meta ["val" ].dtype , torch .float )
207-
208- self .assertEqual (mult_count , 1 )
164+ count = count_nodes_with_target_asserting_arguments_have_dtype (
165+ new_graph_module , op , second_arg_dtype
166+ )
167+ self .assertEqual (count , expected_count )
209168
210169 def test_remove_noop_pass (self ) -> None :
211170 class Foo (torch .nn .Module ):
0 commit comments