
Commit 6e7884f

use lru cache to replace global cache variable
1 parent 5a5e829 commit 6e7884f

File tree

2 files changed: +7, -18 lines changed

.ci/scripts/test-cuda-build.sh

Lines changed: 3 additions & 3 deletions
@@ -8,9 +8,9 @@
 set -exu
 
 # The generic Linux job chooses to use base env, not the one setup by the image
-# eval "$(conda shell.bash hook)"
-# CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
-# conda activate "${CONDA_ENV}"
+eval "$(conda shell.bash hook)"
+CONDA_ENV=$(conda info --envs | awk '/base/ {print $2}')
+conda activate "${CONDA_ENV}"
 
 CUDA_VERSION=${1:-"12.6"}
 
install_utils.py

Lines changed: 4 additions & 15 deletions
@@ -5,6 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 import os
 import platform
 import re
@@ -101,14 +102,11 @@ def _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base):
     return f"{torch_nightly_url_base}/{cuda_suffix}"
 
 
-# Global variable for caching torch URL
-_torch_url_cache = ""
-
-
+@functools.lru_cache(maxsize=1)
 def determine_torch_url(torch_nightly_url_base, supported_cuda_versions):
     """
     Determine the appropriate PyTorch installation URL based on CUDA availability and CMAKE_ARGS.
-    Uses caching to avoid redundant CUDA detection and print statements.
+    Uses @functools.lru_cache to avoid redundant CUDA detection and print statements.
 
     Args:
         torch_nightly_url_base: Base URL for PyTorch nightly packages
@@ -117,17 +115,10 @@ def determine_torch_url(torch_nightly_url_base, supported_cuda_versions):
     Returns:
         URL string for PyTorch packages
     """
-    global _torch_url_cache
-
-    # Return cached URL if already determined
-    if _torch_url_cache:
-        return _torch_url_cache
-
     # Check if CUDA delegate is enabled
     if not _is_cuda_enabled():
         print("CUDA delegate not enabled, using CPU-only PyTorch")
-        _torch_url_cache = f"{torch_nightly_url_base}/cpu"
-        return _torch_url_cache
+        return f"{torch_nightly_url_base}/cpu"
 
     print("CUDA delegate enabled, detecting CUDA version...")
 
@@ -141,8 +132,6 @@ def determine_torch_url(torch_nightly_url_base, supported_cuda_versions):
     torch_url = _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base)
     print(f"Using PyTorch URL: {torch_url}")
 
-    # Cache the result
-    _torch_url_cache = torch_url
     return torch_url
 
 