@@ -772,15 +772,18 @@ function install_nvidia_cudnn() {
772772}
773773
774774function install_pytorch() {
775- if test -f " ${workdir} /complete/pytorch" ; then return ; fi
775+ is_complete pytorch && return
776+
776777 local env
777778 env=$( get_metadata_attribute ' gpu-conda-env' ' dpgce' )
778779 local mc3=/opt/conda/miniconda3
779780 local envpath=" ${mc3} /envs/${env} "
781+ if [[ " ${env} " == " base" ]]; then
782+ echo " WARNING: installing to base environment known to cause solve issues" ; envpath=" ${mc3} " ; fi
780783 # Set numa node to 0 for all GPUs
781784 for f in $( ls /sys/module/nvidia/drivers/pci:nvidia/* /numa_node) ; do echo 0 > ${f} ; done
782785
783- local build_tarball=" pytorch_${_shortname} _cuda${CUDA_VERSION} .tar.gz"
786+ local build_tarball=" pytorch_${env} _ ${ _shortname} _cuda${CUDA_VERSION} .tar.gz"
784787 local local_tarball=" ${workdir} /${build_tarball} "
785788 local gcs_tarball=" ${pkg_bucket} /conda/${_shortname} /${build_tarball} "
786789
@@ -805,17 +808,28 @@ function install_pytorch() {
805808 if test -d " ${envpath} " ; then verb=install ; fi
806809 cudart_spec=" cuda-cudart"
807810 if le_cuda11 ; then cudart_spec=" cudatoolkit" ; fi
811+
812+ # Install pytorch and company to this environment
808813 " ${mc3} /bin/mamba" " ${verb} " -n " ${env} " \
809814 -c conda-forge -c nvidia -c rapidsai \
810815 numba pytorch tensorflow[and-cuda] rapids pyspark \
811816 " cuda-version<=${CUDA_VERSION} " " ${cudart_spec} "
817+
818+ # Install jupyter kernel in this environment
819+ " ${envpath} /bin/python3" -m pip install ipykernel
820+
821+ # package environment and cache in GCS
812822 pushd " ${envpath} "
813823 tar czf " ${local_tarball} " .
814824 popd
815825 gcloud storage cp " ${local_tarball} " " ${gcs_tarball} "
816826 if gcloud storage ls " ${gcs_tarball} .building" ; then gcloud storage rm " ${gcs_tarball} .building" || true ; fi
817827 fi
818- touch " ${workdir} /complete/pytorch"
828+
829+ # register the environment as a selectable kernel
830+ " ${envpath} /bin/python3" -m ipykernel install --name " ${env} " --display-name " Python (${env} )"
831+
832+ mark_complete pytorch
819833}
820834
821835function configure_dkms_certs() {
0 commit comments