@@ -286,13 +286,14 @@ def test_advance_branched_quantize(self) -> None:
286286 @torch .no_grad ()
287287 def test_advance_quantize (self ) -> None :
288288 builder = GraphBuilder ()
289- x = builder . placeholder ( "x" , torch .randn (16 , 1 , 6 , 32 , dtype = torch .float32 ) )
290- weights = builder . placeholder (
291- "weights " , torch . randint ( - 128 , 127 , ( 32 , 32 ), dtype = torch . int8 )
292- )
289+ x_data = torch .randn (16 , 1 , 32 , 6 , dtype = torch .float32 )
290+ weight_data = torch . randint ( - 128 , 127 , ( 32 , 32 ), dtype = torch . int8 )
291+ x = builder . placeholder ( "x " , x_data )
292+ weights = builder . placeholder ( "weights" , weight_data )
293293 full = builder .call_operator (
294294 op = exir_ops .edge .aten .full .default ,
295295 args = ([1 ], - 7 ),
296+ kwargs = {"dtype" : torch .int32 },
296297 )
297298 full_1 = builder .call_operator (
298299 op = exir_ops .edge .aten .full .default ,
@@ -304,7 +305,8 @@ def test_advance_quantize(self) -> None:
304305 )
305306 full_3 = builder .call_operator (
306307 op = exir_ops .edge .aten .full .default ,
307- args = ([12 ], 0.0 ),
308+ args = ([1 ], 0 ),
309+ kwargs = {"dtype" : torch .int32 },
308310 )
309311 permute = builder .call_operator (
310312 op = exir_ops .edge .aten .permute_copy .default ,
@@ -337,8 +339,13 @@ def test_advance_quantize(self) -> None:
337339
338340 p1 = AdvanceQuantizeOpAboveDefInBranchPass ()
339341 tmp_graph = cast (PassResult , p1 (original_graph )).graph_module
340- p2 = AdvanceQuantizeOpAboveDefChainPass ()
341- converted_graph = cast (PassResult , p2 (tmp_graph )).graph_module
342+ result = transform_and_check_numerics (
343+ tmp_graph ,
344+ (x_data , weight_data ),
345+ AdvanceQuantizeOpAboveDefChainPass (),
346+ )
347+ self .assertFalse (result .modified )
348+ converted_graph = result .graph_module
342349 # Assert that permute node is now the successor of the quant node.
343350 self .assertTrue (
344351 get_node_pos (
@@ -349,13 +356,14 @@ def test_advance_quantize(self) -> None:
349356
350357 def test_postpone_dequantize1 (self ) -> None :
351358 builder = GraphBuilder ()
352- x = builder . placeholder ( "x" , torch .randn (1 , 16 , 32 , 6 , dtype = torch .float32 ) )
353- weights = builder . placeholder (
354- "weights " , torch . randint ( - 128 , 127 , ( 6 , 6 ), dtype = torch . int8 )
355- )
359+ x_data = torch .randn (1 , 16 , 32 , 6 , dtype = torch .float32 )
360+ weight_data = torch . randint ( - 128 , 127 , ( 6 , 6 ), dtype = torch . int8 )
361+ x = builder . placeholder ( "x " , x_data )
362+ weights = builder . placeholder ( "weights" , weight_data )
356363 full = builder .call_operator (
357364 op = exir_ops .edge .aten .full .default ,
358365 args = ([1 ], - 7 ),
366+ kwargs = {"dtype" : torch .int32 },
359367 )
360368 full_1 = builder .call_operator (
361369 op = exir_ops .edge .aten .full .default ,
@@ -367,7 +375,8 @@ def test_postpone_dequantize1(self) -> None:
367375 )
368376 full_3 = builder .call_operator (
369377 op = exir_ops .edge .aten .full .default ,
370- args = ([12 ], 0.0 ),
378+ args = ([1 ], 0 ),
379+ kwargs = {"dtype" : torch .int32 },
371380 )
372381 quantize_per_tensor = builder .call_operator (
373382 op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
@@ -397,8 +406,13 @@ def test_postpone_dequantize1(self) -> None:
397406 )
398407 builder .output ([permute ])
399408 original_graph = builder .get_graph_module ()
400- p = PostponeDequantizeOpBelowUseChainPass ()
401- converted_graph = cast (PassResult , p (original_graph )).graph_module
409+ result = transform_and_check_numerics (
410+ original_graph ,
411+ (x_data , weight_data ),
412+ PostponeDequantizeOpBelowUseChainPass (),
413+ )
414+ self .assertTrue (result .modified )
415+ converted_graph = result .graph_module
402416 # Assert that dequant node is now the successor of the permute node.
403417 self .assertTrue (
404418 get_node_pos (converted_graph , exir_ops .edge .aten .permute_copy .default )
0 commit comments