@@ -467,6 +467,60 @@ def is_custom_device():
467467 return False
468468
469469
470+ def check_cudnn_version_and_compute_capability (
471+ min_cudnn_version = None , min_device_capability = None
472+ ):
473+ """
474+ Check if the current environment meets the specified cuDNN version and device capability requirements.
475+
476+ Args:
477+ min_cudnn_version (int, optional): Minimum required cuDNN version. If None, cuDNN version check is skipped.
478+ min_device_capability (int, optional): Minimum required device capability. If None, device capability check is skipped.
479+
480+ Returns:
481+ bool: True if the environment meets the requirements or if using custom device, False otherwise.
482+ """
483+ if is_custom_device ():
484+ return True
485+
486+ if not core .is_compiled_with_cuda ():
487+ return False
488+
489+ # Check cuDNN version if specified
490+ cudnn_check = True
491+ if min_cudnn_version is not None :
492+ cudnn_check = core .cudnn_version () >= min_cudnn_version
493+
494+ # Check device capability if specified
495+ device_check = True
496+ if min_device_capability is not None :
497+ device_check = (
498+ paddle .device .cuda .get_device_capability ()[0 ]
499+ >= min_device_capability
500+ )
501+
502+ return cudnn_check and device_check
503+
504+
505+ def get_cuda_version ():
506+ if paddle .is_compiled_with_cuda ():
507+ import re
508+
509+ result = os .popen ("nvcc --version" ).read ()
510+ regex = r'release (\S+),'
511+ match = re .search (regex , result )
512+ if match :
513+ num = str (match .group (1 ))
514+ integer , decimal = num .split ('.' )
515+ return int (integer ) * 1000 + int (float (decimal ) * 10 )
516+ else :
517+ return - 1
518+ elif is_custom_device ():
519+ return 13000
520+ else :
521+ return - 1
522+
523+
470524@contextmanager
471525def auto_parallel_test_guard (test_info_path , generated_test_file_path ):
472526 test_info_file , generated_test_file = None , None
0 commit comments