99import os
1010import tempfile
1111import unittest
12- from typing import List , Optional , Tuple
12+ from typing import Callable , List , Optional , Tuple
1313
1414import executorch .exir as exir
1515
7171from functorch .experimental import control_flow
7272
7373from torch import nn
74+ from torch ._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
7475from torch .export import export
7576from torch .export .graph_signature import InputKind , InputSpec , TensorArgument
7677from torch .fx import GraphModule , subgraph_rewriter
@@ -121,39 +122,97 @@ def foo_out(
121122 return a + 1 , None
122123
123124
125+ def simple_promote_dtype (
126+ dtype : torch .dtype , promotion_kind : ELEMENTWISE_TYPE_PROMOTION_KIND
127+ ) -> torch .dtype :
128+ if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND .DEFAULT :
129+ return dtype
130+ if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND .INT_TO_FLOAT :
131+ return dtype if dtype .is_floating_point else torch .float
132+ else :
133+ raise Exception (f"Unsupported promotion kind { promotion_kind } " )
134+
135+
136+ def count_nodes_with_target_asserting_arguments_have_dtype (
137+ self , module , target , arg_dtype
138+ ) -> int :
139+ count = 0
140+ for node in module .graph .nodes :
141+ if node .op == "call_function" and node .target == target :
142+ count += 1
143+ for arg in node .args :
144+ self .assertEqual (arg .meta ["val" ].dtype , arg_dtype )
145+ return count
146+
147+
124148class TestPasses (unittest .TestCase ):
125149 @classmethod
126150 def setUpClass (cls ) -> None :
127151 register_additional_test_aten_ops ()
128152
129153 def test_remove_mixed_type_operators (self ) -> None :
130- def count_nodes_with_target_asserting_arguments_have_dtype (
131- new_graph_module , target , arg_dtype
132- ):
133- count = 0
134- for node in new_graph_module .graph .nodes :
135- if node .op == "call_function" and node .target == target :
136- count += 1
137- for arg in node .args :
138- self .assertEqual (arg .meta ["val" ].dtype , arg_dtype )
139- return count
140-
141- class Add (torch .nn .Module ):
142- def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
143- return (x + y ) + x
144-
145- class Mult (torch .nn .Module ):
146- def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
147- return x * y
148-
149- class Minimum (torch .nn .Module ):
150- def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
151- return torch .minimum (x , y )
154+ def make_module (fwd : Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ]):
155+ class Module (torch .nn .Module ):
156+ def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
157+ return fwd (x , y )
158+
159+ return Module
160+
161+ Add = make_module (lambda x , y : (x + y ) + x )
162+ Mult = make_module (lambda x , y : x * y )
163+ Minimum = make_module (torch .minimum )
164+ DivWithoutMode = make_module (torch .div )
165+ DivWithNoneMode = make_module (lambda x , y : torch .div (x , y , rounding_mode = None ))
166+ DivWithTruncMode = make_module (
167+ lambda x , y : torch .div (x , y , rounding_mode = "trunc" )
168+ )
169+ DivWithFloorMode = make_module (
170+ lambda x , y : torch .div (x , y , rounding_mode = "floor" )
171+ )
152172
153- for module , op , expected_count in (
154- (Add , exir_ops .edge .aten .add .Tensor , 2 ),
155- (Mult , exir_ops .edge .aten .mul .Tensor , 1 ),
156- (Minimum , exir_ops .edge .aten .minimum .default , 1 ),
173+ for module , op , expected_count , promotion_kind in (
174+ (
175+ Add ,
176+ exir_ops .edge .aten .add .Tensor ,
177+ 2 ,
178+ ELEMENTWISE_TYPE_PROMOTION_KIND .DEFAULT ,
179+ ),
180+ (
181+ Mult ,
182+ exir_ops .edge .aten .mul .Tensor ,
183+ 1 ,
184+ ELEMENTWISE_TYPE_PROMOTION_KIND .DEFAULT ,
185+ ),
186+ (
187+ Minimum ,
188+ exir_ops .edge .aten .minimum .default ,
189+ 1 ,
190+ ELEMENTWISE_TYPE_PROMOTION_KIND .DEFAULT ,
191+ ),
192+ (
193+ DivWithoutMode ,
194+ exir_ops .edge .aten .div .Tensor ,
195+ 1 ,
196+ ELEMENTWISE_TYPE_PROMOTION_KIND .INT_TO_FLOAT ,
197+ ),
198+ (
199+ DivWithNoneMode ,
200+ exir_ops .edge .aten .div .Tensor_mode ,
201+ 1 ,
202+ ELEMENTWISE_TYPE_PROMOTION_KIND .INT_TO_FLOAT ,
203+ ),
204+ (
205+ DivWithTruncMode ,
206+ exir_ops .edge .aten .div .Tensor_mode ,
207+ 1 ,
208+ ELEMENTWISE_TYPE_PROMOTION_KIND .DEFAULT ,
209+ ),
210+ (
211+ DivWithFloorMode ,
212+ exir_ops .edge .aten .div .Tensor_mode ,
213+ 1 ,
214+ ELEMENTWISE_TYPE_PROMOTION_KIND .DEFAULT ,
215+ ),
157216 ):
158217 for second_arg_dtype in (torch .int64 , torch .float , torch .double ):
159218 int_tensor = torch .tensor ([[1 , 2 , 3 ]], dtype = torch .int64 )
@@ -166,8 +225,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
166225 new_graph_module = new_prog .exported_program ().graph_module
167226 self .assertIsNotNone (new_graph_module )
168227
228+ promoted_type = simple_promote_dtype (second_arg_dtype , promotion_kind )
169229 count = count_nodes_with_target_asserting_arguments_have_dtype (
170- new_graph_module , op , second_arg_dtype
230+ self , new_graph_module , op , promoted_type
171231 )
172232 self .assertEqual (count , expected_count )
173233
0 commit comments