Commit 76e25e6

[Test] Correctly skip tests on non-Hopper GPUs
The previous skip condition mixed `and` and `or` without parentheses, so it did not express the intended rule: the `D=128` restriction applied only to non-Hopper GPUs, but tests using `CombaConfig` or `GatedDeltaNetConfig` were skipped unconditionally, even on Hopper. This patch groups the conditions so that tests with `D=128`, `CombaConfig`, or `GatedDeltaNetConfig` are skipped only on non-Hopper architectures. It also makes `fla/modules/convolution.py` fall back to the Triton convolution when `bias=False`, working around a bug in `causal_conv1d`'s backward pass.
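
A minimal sketch (not part of the commit) of the operator-precedence issue, with plain booleans standing in for the real checks (`is_nvidia_hopper`, `D == 128`, `config_class in [CombaConfig, GatedDeltaNetConfig]`):

def old_skip(is_hopper, d_is_128, needs_hopper_config):
    # `and` binds tighter than `or`, so the config check applies unconditionally
    return not is_hopper and d_is_128 or needs_hopper_config

def new_skip(is_hopper, d_is_128, needs_hopper_config):
    # parentheses restrict both checks to non-Hopper GPUs
    return not is_hopper and (d_is_128 or needs_hopper_config)

# On a Hopper GPU, a CombaConfig/GatedDeltaNetConfig test was skipped by the
# old condition but runs under the new one:
assert old_skip(is_hopper=True, d_is_128=False, needs_hopper_config=True)
assert not new_skip(is_hopper=True, d_is_128=False, needs_hopper_config=True)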
1 parent c6b21dc commit 76e25e6

File tree

fla/modules/convolution.py
tests/test_model.py

2 files changed: +12 -2 lines changed

fla/modules/convolution.py

Lines changed: 8 additions & 0 deletions
@@ -584,6 +584,14 @@ def __init__(
                 "Consider installing `causal_conv1d` to enable the CUDA implementation."
             )
             self.use_fast_conv1d = False
+        if bias is False:
+            # There is a bug in https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/cpp_functions.py#L135
+            self.use_fast_conv1d = False
+            warnings.warn(
+                "The `use_fast_conv1d` parameter is set to `True`, but bias is set to `False`. "
+                "Switching to the Triton implementation instead. "
+                "Since there is a bug in causal_conv1d that does not support bias during backward pass."
+            )
 
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
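
For illustration, a self-contained sketch of the fallback pattern this hunk adds; the `ConvStub` class and its constructor arguments below are stand-ins, not the real module in `fla/modules/convolution.py`:

import warnings

class ConvStub:
    """Simplified stand-in for the convolution module patched above."""

    def __init__(self, use_fast_conv1d=True, bias=True):
        self.use_fast_conv1d = use_fast_conv1d
        if bias is False:
            # Mirror the new guard: causal_conv1d's backward pass has a bug
            # when no bias is used, so fall back to the Triton implementation.
            self.use_fast_conv1d = False
            warnings.warn(
                "`use_fast_conv1d=True` with `bias=False` is not supported; "
                "falling back to the Triton implementation."
            )

conv = ConvStub(use_fast_conv1d=True, bias=False)  # emits the warning
assert conv.use_fast_conv1d is False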

tests/test_model.py

Lines changed: 4 additions & 2 deletions
@@ -80,8 +80,10 @@ def test_model(
     dtype: torch.dtype,
     use_l2warp: bool
 ):
-    if not is_nvidia_hopper and D == 128 or config_class in [CombaConfig, GatedDeltaNetConfig]:
-        # Due to lack of shared memory
+    if not is_nvidia_hopper and (D == 128 or config_class in [CombaConfig, GatedDeltaNetConfig]):
+        # Skip D=128 for non-Hopper GPUs
+        # CombaConfig and GatedDeltaNetConfig are not supported on non-Hopper GPUs
+        # as they require specific shared memory
         pytest.skip("D=128 is only Tested on Hopper GPUs")
     if config_class in [
         ABCConfig, ForgettingTransformerConfig, LinearAttentionConfig, LightNetConfig,
