9494 'sha256' : (
9595 '59e37f570ba5f3d7148028e96684d77f347d49a54e3722189782fc9b17d201c0'
9696 ),
97+ 'target' : 'resources/nvidia-driver.run'
98+ },
99+ 'visualization' : {
100+ 'url' : 'https://go.microsoft.com/fwlink/?linkid=849941' ,
101+ 'sha256' : (
102+ 'f5e39c9abf6d48d9883cd61d8fec8c67f05c9d6a7cc8b450af0efa790fbbd1a7'
103+ ),
104+ 'target' : 'resources/nvidia-driver-grid.run'
97105 },
98106 'license' : (
99107 'http://www.nvidia.com/content/DriverDownload-March2009'
100108 '/licence.php?lang=us'
101109 ),
102- 'target' : 'resources/nvidia-driver.run'
103110}
104111_NODEPREP_FILE = (
105112 'shipyard_nodeprep.sh' ,
@@ -289,14 +296,8 @@ def _setup_nvidia_driver_package(blob_client, config, vm_size):
289296 :rtype: pathlib.Path
290297 :return: package path
291298 """
292- if settings .is_gpu_compute_pool (vm_size ):
293- gpu_type = 'compute'
294- elif settings .is_gpu_visualization_pool (vm_size ):
295- gpu_type = 'visualization'
296- raise RuntimeError (
297- ('pool consisting of {} nodes require gpu driver '
298- 'configuration' ).format (vm_size ))
299- pkg = pathlib .Path (_ROOT_PATH , _NVIDIA_DRIVER ['target' ])
299+ gpu_type = settings .get_gpu_type_from_vm_size (vm_size )
300+ pkg = pathlib .Path (_ROOT_PATH , _NVIDIA_DRIVER [gpu_type ]['target' ])
300301 # check to see if package is downloaded
301302 if (not pkg .exists () or
302303 util .compute_sha256_for_file (pkg , False ) !=
@@ -314,14 +315,14 @@ def _setup_nvidia_driver_package(blob_client, config, vm_size):
314315 logger .info ('NVIDIA Software License accepted' )
315316 # download driver
316317 logger .debug ('downloading NVIDIA driver to {}' .format (
317- _NVIDIA_DRIVER ['target' ]))
318+ _NVIDIA_DRIVER [gpu_type ][ 'target' ]))
318319 response = requests .get (_NVIDIA_DRIVER [gpu_type ]['url' ], stream = True )
319320 with pkg .open ('wb' ) as f :
320321 for chunk in response .iter_content (chunk_size = _REQUEST_CHUNK_SIZE ):
321322 if chunk :
322323 f .write (chunk )
323324 logger .debug ('wrote {} bytes to {}' .format (
324- pkg .stat ().st_size , _NVIDIA_DRIVER ['target' ]))
325+ pkg .stat ().st_size , _NVIDIA_DRIVER [gpu_type ][ 'target' ]))
325326 # check sha256
326327 if (util .compute_sha256_for_file (pkg , False ) !=
327328 _NVIDIA_DRIVER [gpu_type ]['sha256' ]):
@@ -833,7 +834,9 @@ def _add_pool(
833834 blob_client , config , pool_settings .vm_size )
834835 _rflist .append ((gpu_driver .name , gpu_driver ))
835836 else :
836- gpu_driver = pathlib .Path (_NVIDIA_DRIVER ['target' ])
837+ gpu_type = settings .get_gpu_type_from_vm_size (
838+ pool_settings .vm_size )
839+ gpu_driver = pathlib .Path (_NVIDIA_DRIVER [gpu_type ]['target' ])
837840 gpupkg = _setup_nvidia_docker_package (blob_client , config )
838841 _rflist .append ((gpupkg .name , gpupkg ))
839842 gpu_env = '{}:{}:{}' .format (
0 commit comments