diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 4d27621..b78af52 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -62,7 +62,7 @@ def flash_mla_with_kvcache(
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     if indices is not None:
-        assert causal == False, "causal must be `false` if sparse attention is enabled."
+        assert not causal, "causal must be False when sparse attention is enabled (indices provided)"
     out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
         q,
         k_cache,
diff --git a/setup.py b/setup.py
index 15fa671..7e5424d 100644
--- a/setup.py
+++ b/setup.py
@@ -104,7 +104,8 @@ def get_nvcc_thread_args():
 try:
     cmd = ['git', 'rev-parse', '--short', 'HEAD']
     rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
-except Exception as _:
+except (subprocess.CalledProcessError, FileNotFoundError, OSError):
+    # Fallback to timestamp if git is not available or not in a git repo
     now = datetime.now()
     date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
     rev = '+' + date_time_str