|
40 | 40 | from apache_beam.runners.internal import names |
41 | 41 | from apache_beam.runners.portability import stager |
42 | 42 |
|
| 43 | +from concurrent.futures import ThreadPoolExecutor |
| 44 | +from concurrent.futures import as_completed |
| 45 | + |
43 | 46 | _LOGGER = logging.getLogger(__name__) |
44 | 47 |
|
45 | 48 | try: |
@@ -913,6 +916,59 @@ def test_populate_requirements_cache_with_local_files(self): |
913 | 916 | self.assertNotIn('fake_pypi', extra_packages_contents) |
914 | 917 | self.assertIn('local_package', extra_packages_contents) |
915 | 918 |
|
| 919 | + def test_requirements_cache_creation_no_race_condition(self): |
| 920 | + base_cache_dir = self.make_temp_dir() |
| 921 | + cache_dir = os.path.join(base_cache_dir, 'test-requirements-cache') |
| 922 | + # Ensure the directory doesn't exist initially |
| 923 | + if os.path.exists(cache_dir): |
| 924 | + shutil.rmtree(cache_dir) |
| 925 | + |
| 926 | + source_dir = self.make_temp_dir() |
| 927 | + requirements_file = os.path.join(source_dir, stager.REQUIREMENTS_FILE) |
| 928 | + self.create_temp_file(requirements_file, 'requests>=2.0.0\n') |
| 929 | + |
| 930 | + def create_resources_with_cache(): |
| 931 | + temp_dir = tempfile.mkdtemp() |
| 932 | + try: |
| 933 | + options = PipelineOptions() |
| 934 | + self.update_options(options) |
| 935 | + setup_options = options.view_as(SetupOptions) |
| 936 | + setup_options.requirements_file = requirements_file |
| 937 | + setup_options.requirements_cache = cache_dir |
| 938 | + # This should create the cache directory if it doesn't exist |
| 939 | + stager.Stager.create_job_resources( |
| 940 | + options, |
| 941 | + temp_dir, |
| 942 | + populate_requirements_cache=self.populate_requirements_cache) |
| 943 | + return True, None |
| 944 | + except Exception as e: |
| 945 | + return False, e |
| 946 | + finally: |
| 947 | + if os.path.exists(temp_dir): |
| 948 | + shutil.rmtree(temp_dir) |
| 949 | + |
| 950 | + # Run multiple threads concurrently to create |
| 951 | + # resources with the same cache dir. |
| 952 | + num_threads = 10 |
| 953 | + successes = 0 |
| 954 | + with ThreadPoolExecutor(max_workers=num_threads) as executor: |
| 955 | + futures = [ |
| 956 | + executor.submit(create_resources_with_cache) |
| 957 | + for _ in range(num_threads) |
| 958 | + ] |
| 959 | + |
| 960 | + for future in as_completed(futures): |
| 961 | + success, _ = future.result() |
| 962 | + if success: |
| 963 | + successes += 1 |
| 964 | + # All threads should succeed |
| 965 | + self.assertEqual( |
| 966 | + successes, |
| 967 | + num_threads, |
| 968 | + f"Expected all {num_threads} threads to pass, but got errors.") |
| 969 | + # Verify that the cache directory exists |
| 970 | + self.assertTrue(os.path.isdir(cache_dir)) |
| 971 | + |
916 | 972 |
|
917 | 973 | class TestStager(stager.Stager): |
918 | 974 | def stage_artifact(self, local_path_to_artifact, artifact_name, sha256): |
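
The race this test guards against typically comes from a check-then-create
pattern on the requirements cache directory. The sketch below is a minimal,
hypothetical illustration of the idempotent creation the test expects, not
the actual Beam stager code; the helper name _ensure_cache_dir is assumed
for illustration only.

    import os

    def _ensure_cache_dir(cache_dir):
      # A naive "if not os.path.exists(cache_dir): os.mkdir(cache_dir)" is
      # racy: two threads can both pass the check and one then fails with
      # FileExistsError. exist_ok=True lets the creation succeed in every
      # thread, whichever one actually makes the directory.
      os.makedirs(cache_dir, exist_ok=True)
      return cache_dir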
|