@@ -1279,6 +1279,7 @@ class Module(torch.nn.Module):
12791279 def __init__ (self ):
12801280 super ().__init__ ()
12811281 self .linear = torch .nn .Linear (3 , 3 )
1282+ self .w = torch .randn (3 , 3 )
12821283
12831284 def t (self , val ):
12841285 return val + 1
@@ -1293,8 +1294,11 @@ def false_fn(self, val):
12931294 return self .linear (val ) - self .f (val )
12941295
12951296 def forward (self , pred , x ):
1297+ out = torch .nn .functional .linear (
1298+ x , self .w .to (torch .float16 ).to (torch .float32 )
1299+ )
12961300 return torch .ops .higher_order .cond (
1297- pred , self .true_fn , self .false_fn , [x ]
1301+ pred , self .true_fn , self .false_fn , [out ]
12981302 )
12991303
13001304 mod = Module ()
@@ -1304,14 +1308,48 @@ def forward(self, pred, x):
13041308 export (mod , (pred , x ), strict = True ),
13051309 compile_config = exir .EdgeCompileConfig (_check_ir_validity = False ),
13061310 )
1307- error_msg = r"constant_prop_pass for control flow is not supported yet."
1308-
1309- # TODO(chenlai): enable constant prop pass for control flow
1310- with self .assertRaisesRegex (
1311- RuntimeError ,
1312- error_msg ,
1313- ):
1314- _ = constant_prop_pass (edge .exported_program ())
1311+ expected_out = edge .exported_program ().module ()(pred , x )
1312+
1313+ warn_log = (
1314+ "constant_prop_pass does not constant propagate in control flow modules"
1315+ )
1316+ with self .assertLogs (level = "WARNING" ) as log :
1317+ program = constant_prop_pass (edge .exported_program ())
1318+ self .assertIn (warn_log , log .output [0 ])
1319+
1320+ out = program .module ()(pred , x )
1321+ self .assertTrue (torch .allclose (expected_out , out ))
1322+
1323+ # dtype casts in parent module are const propagated
1324+ FileCheck ().check (
1325+ "executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant"
1326+ ).run (program .graph_module .code )
1327+
def test_constant_prop_pass_quant_primitives(self) -> None:
    """Quantization primitives must NOT be folded by constant_prop_pass.

    Builds a module whose weight is dequantized at runtime, lowers it to
    edge dialect, runs constant propagation, and then asserts that the
    dequantize op is still present in the resulting graph code.
    """

    class QuantWeightModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # int8 weight with its associated scale / zero-point params.
            self.w_int = torch.ones(3, 3, dtype=torch.int8)
            self.w_scale = 3.0
            self.w_zero_point = 3

        def forward(self, x):
            # Dequantize the stored int8 weight, then apply it as a linear.
            w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8
            )
            return torch.nn.functional.linear(x, w_dq)

    module = QuantWeightModule()
    sample = torch.randn([3])
    module(sample)  # eager sanity run before export
    edge = to_edge(
        export(module, (sample,), strict=True),
        compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    )
    constant_prop_pass(edge.exported_program())
    # The dequantize edge op must survive the pass (quant primitives are
    # intentionally excluded from constant propagation).
    FileCheck().check(
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
    ).run(edge.exported_program().graph_module.code)
13151353
13161354 def test_mutable_buffers (self ) -> None :
13171355 def count_copies (gm : torch .fx .GraphModule ) -> int :
0 commit comments