@@ -25,7 +25,7 @@ build_backend <- function(conda_env = "fastrerandomize_env", conda = "auto"){
2525 # Create a new conda environment
2626 reticulate :: conda_create(envname = conda_env ,
2727 conda = conda ,
28- python_version = " 3.11 " )
28+ python_version = " 3.12 " )
2929
3030 os <- Sys.info()[[" sysname" ]]
3131 machine <- Sys.info()[" machine" ]
@@ -52,20 +52,20 @@ build_backend <- function(conda_env = "fastrerandomize_env", conda = "auto"){
5252 # Prefer CUDA 13 if the driver is new enough; otherwise CUDA 12; else CPU fallback
5353 if (! is.na(drv_major ) && drv_major > = 580 ) {
5454 msg(" Driver %s detected (>=580): installing JAX 0.5.0 CUDA 13 wheels." , drv [1 ])
55- tryCatch(pip_install(' jax[cuda13]==0.5.0 ' ), error = function (e ) {
55+ tryCatch(pip_install(' jax[cuda13]' ), error = function (e ) {
5656 msg(" CUDA 13 wheels failed (%s); falling back to CUDA 12." , e $ message )
57- pip_install(' jax[cuda12]==0.5.0 ' )
57+ pip_install(' jax[cuda12]' )
5858 })
5959 } else if (! is.na(drv_major ) && drv_major > = 525 ) {
6060 msg(" Driver %s detected (>=525,<580): installing JAX 0.5.0 CUDA 12 wheels." , drv [1 ])
61- pip_install(' jax[cuda12]==0.5.0 ' )
61+ pip_install(' jax[cuda12]' )
6262 } else {
6363 msg(" Driver %s too old for CUDA wheels; installing CPU-only JAX 0.5.0." , drv [1 ])
64- pip_install(' jax==0.5.0 ' )
64+ pip_install(' jax' )
6565 }
6666 } else {
6767 msg(" Installing CPU-only JAX 0.5.0." )
68- pip_install(' jax==0.5.0 ' )
68+ pip_install(' jax' )
6969 }
7070 }
7171
0 commit comments