Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions op_builder/evoformer_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .builder import CUDAOpBuilder, installed_cuda_version
import os
from packaging.version import Version


class EvoformerAttnBuilder(CUDAOpBuilder):
Expand All @@ -29,18 +30,6 @@ def sources(self):
src_dir = 'csrc/deepspeed4science/evoformer_attn'
return [f'{src_dir}/attention.cpp', f'{src_dir}/attention_back.cu', f'{src_dir}/attention_cu.cu']

def nvcc_args(self):
    """Return the parent builder's nvcc arguments plus a GPU_ARCH define.

    The extra flag has the form ``-DGPU_ARCH=<major><minor>``, where the
    two numbers are read from the ``major`` and ``minor`` fields of the
    properties torch reports for CUDA device 0.

    If torch cannot be imported, a warning is emitted and the parent's
    argument list is returned without the extra define.
    """
    nvcc_flags = super().nvcc_args()
    try:
        import torch
    except ImportError:
        # Without torch the device cannot be queried; fall through and
        # hand back the inherited flags untouched.
        self.warning("Please install torch if trying to pre-compile kernels")
    else:
        # Device index 0 is hard-coded here; the flag reflects that device only.
        arch_major = torch.cuda.get_device_properties(0).major  #ignore-cuda
        arch_minor = torch.cuda.get_device_properties(0).minor  #ignore-cuda
        nvcc_flags.append("-DGPU_ARCH={}{}".format(arch_major, arch_minor))
    return nvcc_flags

def is_compatible(self, verbose=False):
try:
import torch
Expand All @@ -66,9 +55,7 @@ def is_compatible(self, verbose=False):
if verbose:
self.warning("Please pip install nvidia-cutlass if trying to pre-compile kernels")
return False
cutlass_major, cutlass_minor = cutlass.__version__.split('.')[:2]
cutlass_compatible = (int(cutlass_major) >= 3 and int(cutlass_minor) >= 1)
if not cutlass_compatible:
if Version(cutlass.__version__) < Version('3.1.0'):
if verbose:
self.warning("Please use CUTLASS version >= 3.1.0")
return False
Expand Down