@@ -50,30 +50,29 @@ def _init_pipeline(self, quantization_config, torch_dtype):
         )
         return pipe
 
-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            for _ in range(2):
+                # small resolutions to ensure speedy execution.
+                pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()
 
         for _ in range(2):
             # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000
 
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
@@ -89,4 +88,13 @@ def _test_torch_compile_with_group_offload_leaf(
 
         for _ in range(2):
             # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self):
+        self._test_torch_compile_with_group_offload_leaf()
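Note on the `error_on_recompile` guard added in the first hunk: wrapping the warm-up calls in `torch._dynamo.config.patch(error_on_recompile=True)` turns any recompilation after the initial compile into a hard failure, so the test catches silent guard misses instead of just running slower. A minimal, self-contained sketch of the pattern (toy module, not the pipeline under test):

```python
import torch

model = torch.nn.Linear(8, 8)
model.compile(fullgraph=True)  # same nn.Module.compile call used on pipe.transformer

with torch._dynamo.config.patch(error_on_recompile=True):
    for _ in range(2):
        # The first call performs the initial compile (allowed); the second,
        # with identical input shapes, must hit the cached graph, otherwise
        # torch._dynamo raises a RecompileError.
        model(torch.randn(2, 8))
```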
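The offload variants compile without `fullgraph=True`, since offloading hooks can introduce graph breaks. The second hunk also cuts off before `group_offload_kwargs` is applied; a hedged sketch of how leaf-level group offloading is typically combined with compilation in diffusers (the checkpoint id, `offload_type` value, and the apply call are assumptions, not taken from this diff):

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint; the real pipeline comes from _init_pipeline.
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),  # matches the visible kwargs
    offload_device=torch.device("cpu"),  # matches the visible kwargs
    offload_type="leaf_level",           # assumed: leaf-module granularity, per the test name
    use_stream=False,                    # mirrors the use_stream flag in the signature
)
pipe.transformer.compile()
```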