Skip to content

Commit 8b72347

Browse files
committed
Use double quotes consistently
1 parent 01d56b9 commit 8b72347

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

op_builder/evoformer_attn.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,24 @@ class EvoformerAttnBuilder(CUDAOpBuilder):
1515
def __init__(self, name=None):
1616
name = self.NAME if name is None else name
1717
super().__init__(name=name)
18-
self.cutlass_path = os.environ.get('CUTLASS_PATH')
18+
self.cutlass_path = os.environ.get("CUTLASS_PATH")
1919
# Target GPU architecture
2020
# Current useful values: >70, >75, >80, see gemm_kernel_utils.h
2121
# For modern GPUs, >80 is obfiously the right value
22-
self.gpu_arch = os.environ.get('DS_EVOFORMER_GPU_ARCH')
22+
self.gpu_arch = os.environ.get("DS_EVOFORMER_GPU_ARCH")
2323

2424
def absolute_name(self):
25-
return f'deepspeed.ops.{self.NAME}_op'
25+
return f"deepspeed.ops.{self.NAME}_op"
2626

2727
def extra_ldflags(self):
2828
if not self.is_rocm_pytorch():
29-
return ['-lcurand']
29+
return ["-lcurand"]
3030
else:
3131
return []
3232

3333
def sources(self):
34-
src_dir = 'csrc/deepspeed4science/evoformer_attn'
35-
return [f'{src_dir}/attention.cpp', f'{src_dir}/attention_back.cu', f'{src_dir}/attention_cu.cu']
34+
src_dir = "csrc/deepspeed4science/evoformer_attn"
35+
return [f"{src_dir}/attention.cpp", f"{src_dir}/attention_back.cu", f"{src_dir}/attention_cu.cu"]
3636

3737
def nvcc_args(self):
3838
args = super().nvcc_args()
@@ -69,9 +69,9 @@ def is_compatible(self, verbose=False):
6969
except (RuntimeError, ImportError):
7070
return False
7171
# Check version in case it is a CUTLASS_PATH points to a CUTLASS checkout
72-
if os.path.exists(f'{self.cutlass_path}/CHANGELOG.md'):
73-
with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f:
74-
if '3.1.0' not in f.read():
72+
if os.path.exists(f"{self.cutlass_path}/CHANGELOG.md"):
73+
with open(f"{self.cutlass_path}/CHANGELOG.md", "r") as f:
74+
if "3.1.0" not in f.read():
7575
if verbose:
7676
self.warning("Please use CUTLASS version >= 3.1.0")
7777
return False
@@ -81,7 +81,7 @@ def is_compatible(self, verbose=False):
8181
if not os.environ.get("DS_IGNORE_CUDA_DETECTION"):
8282
if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda
8383
sys_cuda_major, _ = installed_cuda_version()
84-
torch_cuda_major = int(torch.version.cuda.split('.')[0])
84+
torch_cuda_major = int(torch.version.cuda.split(".")[0])
8585
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
8686
if cuda_capability < 7:
8787
if verbose:

0 commit comments

Comments
 (0)