# See the License for the specific language governing permissions and
# limitations under the License.
import gc
+import inspect

import torch

@@ -54,19 +55,16 @@ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
        # `fullgraph=True` ensures no graph breaks
        pipe.transformer.compile(fullgraph=True)

-        with torch._dynamo.config.patch(error_on_recompile=True):
-            for _ in range(2):
-                # small resolutions to ensure speedy execution.
-                pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
        torch._dynamo.config.cache_size_limit = 1000
@@ -85,15 +83,17 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16
                if torch.device(component.device).type == "cpu":
                    component.to("cuda")

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

    def test_torch_compile(self):
        self._test_torch_compile()

    def test_torch_compile_with_cpu_offload(self):
        self._test_torch_compile_with_cpu_offload()

-    def test_torch_compile_with_group_offload_leaf(self):
-        self._test_torch_compile_with_group_offload_leaf()
+    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+        for cls in inspect.getmro(self.__class__):
+            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                return
+        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
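
Note on the override check added above: `inspect.getmro()` returns the class's method resolution order, and `cls.__dict__` only contains attributes defined directly on that class (never inherited ones), so the loop detects whether any subclass of `QuantCompileTests` has re-defined the test and bails out of the base version if so. A minimal sketch of the pattern; the class names here are hypothetical stand-ins, not from this PR:

```python
import inspect


class Base:
    def test_something(self):
        # Skip if any subclass re-defined this test. getmro() walks the
        # method resolution order; cls.__dict__ lists only attributes
        # defined directly on that class, never inherited ones.
        for cls in inspect.getmro(self.__class__):
            if "test_something" in cls.__dict__ and cls is not Base:
                return
        print(f"base test body ran for {self.__class__.__name__}")


class Inherits(Base):
    pass  # no re-definition -> base body runs


class Overrides(Base):
    def test_something(self):  # re-definition -> base body is skipped
        super().test_something()


Inherits().test_something()   # prints
Overrides().test_something()  # prints nothing
```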