@@ -772,15 +772,18 @@ function install_nvidia_cudnn() {
 }
 
 function install_pytorch() {
-  if test -f "${workdir}/complete/pytorch" ; then return ; fi
+  is_complete pytorch && return
+
   local env
   env=$(get_metadata_attribute 'gpu-conda-env' 'dpgce')
   local mc3=/opt/conda/miniconda3
   local envpath="${mc3}/envs/${env}"
+  if [[ "${env}" == "base" ]]; then
+    echo "WARNING: installing to the base environment is known to cause solve issues" ; envpath="${mc3}" ; fi
   # Set numa node to 0 for all GPUs
   for f in $(ls /sys/module/nvidia/drivers/pci:nvidia/*/numa_node) ; do echo 0 > ${f} ; done
 
-  local build_tarball="pytorch_${_shortname}_cuda${CUDA_VERSION}.tar.gz"
+  local build_tarball="pytorch_${env}_${_shortname}_cuda${CUDA_VERSION}.tar.gz"
   local local_tarball="${workdir}/${build_tarball}"
   local gcs_tarball="${pkg_bucket}/conda/${_shortname}/${build_tarball}"
 
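Note: with this change the cached tarball name includes the conda environment, so caches built for different environments no longer collide in GCS. With the default environment name ('dpgce') and purely illustrative values for the other variables (assumptions, not taken from this diff), the name expands as:

    env=dpgce ; _shortname=rocky9 ; CUDA_VERSION=12.4   # illustrative values only
    echo "pytorch_${env}_${_shortname}_cuda${CUDA_VERSION}.tar.gz"
    # -> pytorch_dpgce_rocky9_cuda12.4.tar.gz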
@@ -805,17 +808,28 @@ function install_pytorch() {
     if test -d "${envpath}" ; then verb=install ; fi
     cudart_spec="cuda-cudart"
     if le_cuda11 ; then cudart_spec="cudatoolkit" ; fi
+
+    # Install pytorch and company to this environment
     "${mc3}/bin/mamba" "${verb}" -n "${env}" \
       -c conda-forge -c nvidia -c rapidsai \
       numba pytorch tensorflow[and-cuda] rapids pyspark \
       "cuda-version<=${CUDA_VERSION}" "${cudart_spec}"
+
+    # Install jupyter kernel in this environment
+    "${envpath}/bin/python3" -m pip install ipykernel
+
+    # package environment and cache in GCS
     pushd "${envpath}"
     tar czf "${local_tarball}" .
     popd
     gcloud storage cp "${local_tarball}" "${gcs_tarball}"
     if gcloud storage ls "${gcs_tarball}.building" ; then gcloud storage rm "${gcs_tarball}.building" || true ; fi
   fi
-  touch "${workdir}/complete/pytorch"
+
+  # register the environment as a selectable kernel
+  "${envpath}/bin/python3" -m ipykernel install --name "${env}" --display-name "Python (${env})"
+
+  mark_complete pytorch
 }
 
 function configure_dkms_certs() {
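Note: the ipykernel registration added above makes the environment selectable as a notebook kernel. A quick post-install sanity check, assuming the default environment name 'dpgce' (the path shown is the usual system-wide kernelspec location, not something this diff guarantees):

    # expect an entry named after the env, typically under /usr/local/share/jupyter/kernels/<env>
    jupyter kernelspec list | grep -i dpgce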
@@ -2067,11 +2081,11 @@ function harden_sshd_config() {
     feature_map["kex-gss"]="gssapikexalgorithms" ; fi
   for ftr in "${!feature_map[@]}" ; do
     export feature=${feature_map[$ftr]}
-    sshd_config_line=$(
+    sshd_config_line="${feature} $(
       (sshd -T | awk "/^${feature} / {print \$2}" | sed -e 's/,/\n/g' ;
        ssh -Q "${ftr}" ) \
-      | sort -u | perl -e '@a=grep{!/(sha1|md5)/ig}<STDIN>;
-        print("$ENV{feature} ",join(q",",map{ chomp; $_ }@a), $/) if "@a"' )
+      | sort -u | grep -v -ie sha1 -e md5 | paste -sd "," - )"
+
     grep -iv "^${feature}" /etc/ssh/sshd_config > /tmp/sshd_config_new
     echo "$sshd_config_line" >> /tmp/sshd_config_new
     # TODO: test whether sshd will reload with this change before mv
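Note: the grep/paste pipeline keeps the output shape of the perl one-liner it replaces, i.e. "<feature> alg1,alg2,...", with any sha1/md5 entries dropped. A standalone sketch using made-up algorithm names:

    printf '%s\n' curve25519-sha256 diffie-hellman-group1-sha1 diffie-hellman-group14-sha256 \
      | sort -u | grep -v -ie sha1 -e md5 | paste -sd "," -
    # -> curve25519-sha256,diffie-hellman-group14-sha256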