Commit 6eb5588

clean get_cuda_version() < 11020 in tests (PaddlePaddle#75811)
* fix
* fix
1 parent 66483e0 commit 6eb5588
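
For context: these tests encode the CUDA toolkit version as an integer (major * 1000 + minor * 10), so the dropped threshold 11020 is exactly CUDA 11.2. Below is a minimal sketch of how such a helper is typically written in Paddle's test files, assuming it parses `nvcc --version`; illustrative only, the real definition lives in the test sources and may differ.

import os
import re


def get_cuda_version():
    # Parse e.g. "release 11.2, V11.2.152" from `nvcc --version` and
    # encode it as major * 1000 + minor * 10 (11.2 -> 11020, so the
    # removed check `get_cuda_version() < 11020` meant "CUDA < 11.2").
    output = os.popen("nvcc --version").read()
    match = re.search(r"release (\S+),", output)
    if match is None:
        return -1  # nvcc missing or unparseable
    major, minor = match.group(1).split(".")
    return int(major) * 1000 + int(float(minor) * 10)

With a CUDA 11.2+ toolchain presumably now the minimum Paddle supports, the `< 11020` branch can never trigger, which is why the commit removes it.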

File tree

2 files changed: +19 -31 lines

2 files changed

+19
-31
lines changed

test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py

Lines changed: 8 additions & 8 deletions
@@ -39,8 +39,8 @@ def get_cuda_version():
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
-    "weight_only_linear requires CUDA >= 11.2",
+    not core.is_compiled_with_cuda(),
+    "weight_only_linear requires compiled with CUDA",
 )
 class TestFusedWeightOnlyLinearPass_WithBias(PassTest):
     def is_config_valid(self, w_shape, bias_shape):

@@ -146,8 +146,8 @@ def test_check_output(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
-    "weight_only_linear requires CUDA >= 11.2",
+    not core.is_compiled_with_cuda(),
+    "weight_only_linear requires compiled with CUDA",
 )
 class TestFusedWeightOnlyLinearPass_NoBias(PassTest):
     def get_valid_op_map(self, dtype, w_shape):

@@ -233,8 +233,8 @@ def test_check_output(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
-    "weight_only_linear requires CUDA >= 11.2",
+    not core.is_compiled_with_cuda(),
+    "weight_only_linear requires compiled with CUDA",
 )
 class TestFusedWeightOnlyLinearPass_Weight_Only_Int8(
     TestFusedWeightOnlyLinearPass_NoBias

@@ -252,8 +252,8 @@ def setUp(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
-    "weight_only_linear requires CUDA >= 11.2",
+    not core.is_compiled_with_cuda(),
+    "weight_only_linear requires compiled with CUDA",
 )
 class TestFusedWeightOnlyLinearPass_Weight_Only_Int8_WithBias(
     TestFusedWeightOnlyLinearPass_WithBias
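
After this commit, the pass tests above skip only on non-CUDA builds, with no toolkit-version gate. A self-contained sketch of the resulting pattern, assuming the `from paddle.base import core` import used elsewhere in the suite (the PassTest machinery is replaced by a plain unittest.TestCase here):

import unittest

from paddle.base import core


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "weight_only_linear requires compiled with CUDA",
)
class WeightOnlyLinearGuardSketch(unittest.TestCase):
    def test_guard(self):
        # Only reached on CUDA builds; the CUDA >= 11.2 floor is assumed
        # to be implied by the minimum toolkit Paddle now supports.
        self.assertTrue(core.is_compiled_with_cuda())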

test/legacy_test/test_fused_multi_transformer_int8_op.py

Lines changed: 11 additions & 23 deletions
@@ -15,7 +15,6 @@
 
 import numpy as np
 from op_test import get_device_place, is_custom_device
-from test_sparse_attention_op import get_cuda_version
 
 import paddle
 import paddle.nn.functional as F

@@ -131,9 +130,8 @@ def fused_multi_transformer_int8(
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8Op(unittest.TestCase):
     def setUp(self):

@@ -788,9 +786,8 @@ def test_fused_multi_transformer_op(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpFp16(TestFusedMultiTransformerInt8Op):
     def config(self):

@@ -801,9 +798,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKV(TestFusedMultiTransformerInt8Op):
     def config(self):

@@ -817,9 +813,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKVFp16(
     TestFusedMultiTransformerInt8Op

@@ -834,9 +829,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKV(
     TestFusedMultiTransformerInt8Op

@@ -849,9 +843,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKVFp16(
     TestFusedMultiTransformerInt8Op

@@ -866,9 +859,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpPostLayerNormFp16(
     TestFusedMultiTransformerInt8Op

@@ -882,9 +874,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKVPostLayerNorm(
     TestFusedMultiTransformerInt8Op

@@ -900,9 +891,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKVPostLayerNormFp16(
     TestFusedMultiTransformerInt8Op

@@ -918,9 +908,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNorm(
     TestFusedMultiTransformerInt8Op

@@ -934,9 +923,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNormFp16(
     TestFusedMultiTransformerInt8Op
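
The guard that remains in this file still checks compute capability. A hedged sketch of that condition in isolation, dropping the test-only `is_custom_device()` helper and assuming `paddle.device.cuda.get_device_capability()` returns a (major, minor) tuple as documented:

import unittest

import paddle
from paddle.base import core


def _skip_int8_transformer_tests():
    # Mirrors the post-commit condition: skip when not built with CUDA,
    # or when the current GPU reports SM major version below 8.
    if not core.is_compiled_with_cuda():
        return True
    return paddle.device.cuda.get_device_capability()[0] < 8


@unittest.skipIf(
    _skip_int8_transformer_tests(),
    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
)
class Int8TransformerGuardSketch(unittest.TestCase):
    def test_guard(self):
        major, _ = paddle.device.cuda.get_device_capability()
        self.assertGreaterEqual(major, 8)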
