Skip to content

Commit 8c29056

Browse files
committed
tweak cuda 12/13
1 parent 11f24e3 commit 8c29056

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

fastrerandomize/R/FRR_BuildBackend.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ build_backend <- function(conda_env = "fastrerandomize_env", conda = "auto"){
2525
# Create a new conda environment
2626
reticulate::conda_create(envname = conda_env,
2727
conda = conda,
28-
python_version = "3.11")
28+
python_version = "3.12")
2929

3030
os <- Sys.info()[["sysname"]]
3131
machine <- Sys.info()["machine"]
@@ -52,20 +52,20 @@ build_backend <- function(conda_env = "fastrerandomize_env", conda = "auto"){
5252
# Prefer CUDA 13 if the driver is new enough; otherwise CUDA 12; else CPU fallback
5353
if (!is.na(drv_major) && drv_major >= 580) {
5454
msg("Driver %s detected (>=580): installing JAX 0.5.0 CUDA 13 wheels.", drv[1])
55-
tryCatch(pip_install('jax[cuda13]==0.5.0'), error = function(e) {
55+
tryCatch(pip_install('jax[cuda13]'), error = function(e) {
5656
msg("CUDA 13 wheels failed (%s); falling back to CUDA 12.", e$message)
57-
pip_install('jax[cuda12]==0.5.0')
57+
pip_install('jax[cuda12]')
5858
})
5959
} else if (!is.na(drv_major) && drv_major >= 525) {
6060
msg("Driver %s detected (>=525,<580): installing JAX 0.5.0 CUDA 12 wheels.", drv[1])
61-
pip_install('jax[cuda12]==0.5.0')
61+
pip_install('jax[cuda12]')
6262
} else {
6363
msg("Driver %s too old for CUDA wheels; installing CPU-only JAX 0.5.0.", drv[1])
64-
pip_install('jax==0.5.0')
64+
pip_install('jax')
6565
}
6666
} else {
6767
msg("Installing CPU-only JAX 0.5.0.")
68-
pip_install('jax==0.5.0')
68+
pip_install('jax')
6969
}
7070
}
7171

0 commit comments

Comments
 (0)