diff --git a/python/setup.py b/python/setup.py index 95e0519b37..65388d8664 100644 --- a/python/setup.py +++ b/python/setup.py @@ -119,13 +119,13 @@ def find_visual_studio(version_ranges): for version_range in version_ranges: command = [ str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", - "-property", "installationPath", "-prerelease" + "-products", "*", "-property", "installationPath", "-prerelease" ] try: output = subprocess.check_output(command, text=True).strip() if output: - return output + return output.split("\n")[0] except subprocess.CalledProcessError: continue @@ -146,6 +146,13 @@ def set_env_vars(vs_path, arch="x64"): os.environ[var] = value +def initialize_visual_studio_env(version_ranges, arch="x64"): + vs_path = find_visual_studio(version_ranges) + if not vs_path: + raise EnvironmentError("Visual Studio not found in specified version ranges.") + set_env_vars(vs_path, arch) + + # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] @@ -447,10 +454,7 @@ def build_extension(self, ext): lit_dir = shutil.which('lit') ninja_dir = shutil.which('ninja') if platform.system() == "Windows": - vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"]) - env = set_env_vars(vs_path) - if not vs_path: - raise EnvironmentError("Visual Studio 2019 or 2022 not found.") + initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"]) # lit is used by the test suite thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) thirdparty_cmake_args += self.get_pybind11_cmake_args() diff --git a/python/triton/runtime/CLFinder.py b/python/triton/runtime/CLFinder.py index 2021e0b04f..1a489a2e3c 100644 --- a/python/triton/runtime/CLFinder.py +++ b/python/triton/runtime/CLFinder.py @@ -19,13 +19,13 @@ def find_visual_studio(version_ranges): for version_range in version_ranges: command = [ str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", - "-property", "installationPath", "-prerelease" + "-products", "*", "-property", "installationPath", "-prerelease" ] try: output = subprocess.check_output(command, text=True).strip() if output: - return output + return output.split("\n")[0] except subprocess.CalledProcessError: continue @@ -37,7 +37,7 @@ def set_env_vars(vs_path, arch="x64"): if not vcvarsall_path.exists(): raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}") - command = f'call "{vcvarsall_path}" {arch} && set' + command = ["call", vcvarsall_path, arch, "&&", "set"] output = subprocess.check_output(command, shell=True, text=True) for line in output.splitlines():