Commit ab98f0b

avoid calling gc.collect and cuda.empty_cache (#34514)
* update
* update
* update
* update
* update

Co-authored-by: ydshieh <[email protected]>
1 parent dca93ca commit ab98f0b

24 files changed: +77 -94 lines changed

src/transformers/testing_utils.py

Lines changed: 8 additions & 0 deletions

@@ -16,6 +16,7 @@
 import contextlib
 import doctest
 import functools
+import gc
 import importlib
 import inspect
 import logging
@@ -2679,3 +2680,10 @@ def compare_pipeline_output_to_hub_spec(output, hub_spec):
     if unexpected_keys:
         error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
     raise KeyError("\n".join(error))
+
+
+@require_torch
+def cleanup(device: str, gc_collect=False):
+    if gc_collect:
+        gc.collect()
+    backend_empty_cache(device)
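This new helper centralizes the tearDown boilerplate that the remaining files in this commit replace one by one. A minimal sketch of the intended call pattern follows; the test class here is hypothetical, while cleanup, require_torch, and torch_device are the real testing_utils names:

import unittest

from transformers.testing_utils import cleanup, require_torch, torch_device


@require_torch
class ExampleModelTest(unittest.TestCase):  # hypothetical test class
    def tearDown(self):
        super().tearDown()
        # clean-up as much as possible GPU memory occupied by PyTorch;
        # the heavier suites in this commit also pass gc_collect=True.
        cleanup(torch_device)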

tests/models/clvp/test_feature_extraction_clvp.py

Lines changed: 8 additions & 4 deletions

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import gc
 import itertools
 import os
 import random
@@ -24,7 +23,13 @@
 from datasets import Audio, load_dataset

 from transformers import ClvpFeatureExtractor
-from transformers.testing_utils import check_json_file_has_correct_format, require_torch, slow
+from transformers.testing_utils import (
+    check_json_file_has_correct_format,
+    cleanup,
+    require_torch,
+    slow,
+    torch_device,
+)
 from transformers.utils.import_utils import is_torch_available

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@@ -116,8 +121,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)

     # Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_feat_extract_from_and_save_pretrained
     def test_feat_extract_from_and_save_pretrained(self):

tests/models/clvp/test_modeling_clvp.py

Lines changed: 5 additions & 9 deletions

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Clvp model."""

-import gc
 import tempfile
 import unittest

@@ -23,6 +22,7 @@
 from transformers import ClvpConfig, ClvpDecoderConfig, ClvpEncoderConfig
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     slow,
     torch_device,
@@ -174,8 +174,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)

     def test_config(self):
         self.encoder_config_tester.run_common_tests()
@@ -294,8 +293,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)

     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -421,8 +419,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device)

     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -571,8 +568,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)

     def test_conditional_encoder(self):
         with torch.no_grad():

tests/models/ctrl/test_modeling_ctrl.py

Lines changed: 3 additions & 6 deletions

@@ -13,11 +13,10 @@
 # limitations under the License.


-import gc
 import unittest

 from transformers import CTRLConfig, is_torch_available
-from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -235,8 +234,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device)

     def test_config(self):
         self.config_tester.run_common_tests()
@@ -261,8 +259,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)

     @slow
     def test_lm_generate_ctrl(self):
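CTRL (like GPT-2 and LLaMA below) previously imported backend_empty_cache directly; cleanup now wraps it. As a rough sketch, and only an assumption about that helper's internals, it dispatches an empty_cache call per device type, along these lines:

import torch


def backend_empty_cache_sketch(device: str) -> None:
    # Illustrative only: the real backend_empty_cache lives in
    # transformers.testing_utils and may cover additional backends.
    if device == "cuda":
        torch.cuda.empty_cache()  # release cached CUDA allocator blocks
    elif device == "xpu":
        torch.xpu.empty_cache()   # Intel GPU backend (PyTorch >= 2.4)
    elif device == "mps":
        torch.mps.empty_cache()   # Apple Silicon backend
    # "cpu" has no allocator cache to release, so it is a no-op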

tests/models/gpt2/test_modeling_gpt2.py

Lines changed: 3 additions & 6 deletions

@@ -15,15 +15,14 @@


 import datetime
-import gc
 import math
 import unittest

 import pytest

 from transformers import GPT2Config, is_torch_available
 from transformers.testing_utils import (
-    backend_empty_cache,
+    cleanup,
     require_flash_attn,
     require_torch,
     require_torch_gpu,
@@ -542,8 +541,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device)

     def test_config(self):
         self.config_tester.run_common_tests()
@@ -753,8 +751,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)

     def _test_lm_generate_gpt2_helper(
         self,

tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py

Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@
 from parameterized import parameterized

 from transformers import GPTBigCodeConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -422,9 +422,9 @@ def setUp(self):
         self.config_tester = ConfigTester(self, config_class=GPTBigCodeConfig, n_embd=37)

     def tearDown(self):
-        import gc
-
-        gc.collect()
+        super().tearDown()
+        # clean-up as much as possible GPU memory occupied by PyTorch
+        cleanup(torch_device)

     def test_config(self):
         self.config_tester.run_common_tests()

tests/models/idefics2/test_modeling_idefics2.py

Lines changed: 2 additions & 3 deletions

@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Idefics2 model."""

 import copy
-import gc
 import tempfile
 import unittest
 from io import BytesIO
@@ -31,6 +30,7 @@
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
@@ -583,8 +583,7 @@ def setUp(self):
         )

     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)

     @slow
     @require_torch_multi_gpu

tests/models/idefics3/test_modeling_idefics3.py

Lines changed: 2 additions & 4 deletions

@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Idefics3 model."""

 import copy
-import gc
 import unittest
 from io import BytesIO

@@ -26,7 +25,7 @@
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import cleanup, require_bitsandbytes, require_torch, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -497,8 +496,7 @@ def setUp(self):
         )

     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)

     @slow
     @unittest.skip("multi-gpu tests are disabled for now")

tests/models/llama/test_modeling_llama.py

Lines changed: 2 additions & 4 deletions

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch LLaMA model."""

-import gc
 import tempfile
 import unittest

@@ -25,7 +24,7 @@
 from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
-    backend_empty_cache,
+    cleanup,
     require_flash_attn,
     require_read_token,
     require_torch,
@@ -891,8 +890,7 @@ def test_export_static_cache(self):
 @require_torch_accelerator
 class Mask4DTestHard(unittest.TestCase):
     def tearDown(self):
-        gc.collect()
-        backend_empty_cache(torch_device)
+        cleanup(torch_device, gc_collect=True)

     def setUp(self):
         model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tests/models/llava/test_modeling_llava.py

Lines changed: 2 additions & 3 deletions

@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Llava model."""

-import gc
 import unittest

 import requests
@@ -28,6 +27,7 @@
     is_vision_available,
 )
 from transformers.testing_utils import (
+    cleanup,
     require_bitsandbytes,
     require_torch,
     require_torch_gpu,
@@ -307,8 +307,7 @@ def setUp(self):
         self.processor = AutoProcessor.from_pretrained("llava-hf/bakLlava-v1-hf")

     def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(torch_device, gc_collect=True)

     @slow
     @require_bitsandbytes

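In the files shown here the migration is uniform: fast model-tester suites call cleanup(torch_device) in tearDown, while the slow integration and generation suites (CTRLModelLanguageGenerationTest, GPT2ModelLanguageGenerationTest, Mask4DTestHard, the Idefics and Llava integration tests) pass gc_collect=True, reserving the explicit gc.collect() for tests that load full checkpoints.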