@@ -15,24 +15,24 @@ class EvoformerAttnBuilder(CUDAOpBuilder):
1515 def __init__ (self , name = None ):
1616 name = self .NAME if name is None else name
1717 super ().__init__ (name = name )
18- self .cutlass_path = os .environ .get (' CUTLASS_PATH' )
18+ self .cutlass_path = os .environ .get (" CUTLASS_PATH" )
1919 # Target GPU architecture
2020 # Current useful values: >70, >75, >80, see gemm_kernel_utils.h
2121 # For modern GPUs, >80 is obfiously the right value
22- self .gpu_arch = os .environ .get (' DS_EVOFORMER_GPU_ARCH' )
22+ self .gpu_arch = os .environ .get (" DS_EVOFORMER_GPU_ARCH" )
2323
2424 def absolute_name (self ):
25- return f' deepspeed.ops.{ self .NAME } _op'
25+ return f" deepspeed.ops.{ self .NAME } _op"
2626
2727 def extra_ldflags (self ):
2828 if not self .is_rocm_pytorch ():
29- return [' -lcurand' ]
29+ return [" -lcurand" ]
3030 else :
3131 return []
3232
3333 def sources (self ):
34- src_dir = ' csrc/deepspeed4science/evoformer_attn'
35- return [f' { src_dir } /attention.cpp' , f' { src_dir } /attention_back.cu' , f' { src_dir } /attention_cu.cu' ]
34+ src_dir = " csrc/deepspeed4science/evoformer_attn"
35+ return [f" { src_dir } /attention.cpp" , f" { src_dir } /attention_back.cu" , f" { src_dir } /attention_cu.cu" ]
3636
3737 def nvcc_args (self ):
3838 args = super ().nvcc_args ()
@@ -69,9 +69,9 @@ def is_compatible(self, verbose=False):
6969 except (RuntimeError , ImportError ):
7070 return False
7171 # Check version in case it is a CUTLASS_PATH points to a CUTLASS checkout
72- if os .path .exists (f' { self .cutlass_path } /CHANGELOG.md' ):
73- with open (f' { self .cutlass_path } /CHANGELOG.md' , 'r' ) as f :
74- if ' 3.1.0' not in f .read ():
72+ if os .path .exists (f" { self .cutlass_path } /CHANGELOG.md" ):
73+ with open (f" { self .cutlass_path } /CHANGELOG.md" , "r" ) as f :
74+ if " 3.1.0" not in f .read ():
7575 if verbose :
7676 self .warning ("Please use CUTLASS version >= 3.1.0" )
7777 return False
@@ -81,7 +81,7 @@ def is_compatible(self, verbose=False):
8181 if not os .environ .get ("DS_IGNORE_CUDA_DETECTION" ):
8282 if not self .is_rocm_pytorch () and torch .cuda .is_available (): #ignore-cuda
8383 sys_cuda_major , _ = installed_cuda_version ()
84- torch_cuda_major = int (torch .version .cuda .split ('.' )[0 ])
84+ torch_cuda_major = int (torch .version .cuda .split ("." )[0 ])
8585 cuda_capability = torch .cuda .get_device_properties (0 ).major #ignore-cuda
8686 if cuda_capability < 7 :
8787 if verbose :
0 commit comments