@@ -15,7 +15,6 @@
 
 import numpy as np
 from op_test import get_device_place, is_custom_device
-from test_sparse_attention_op import get_cuda_version
 
 import paddle
 import paddle.nn.functional as F
@@ -131,9 +130,8 @@ def fused_multi_transformer_int8(
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8Op(unittest.TestCase):
     def setUp(self):
@@ -788,9 +786,8 @@ def test_fused_multi_transformer_op(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpFp16(TestFusedMultiTransformerInt8Op):
     def config(self):
@@ -801,9 +798,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKV(TestFusedMultiTransformerInt8Op):
     def config(self):
@@ -817,9 +813,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKVFp16(
     TestFusedMultiTransformerInt8Op
@@ -834,9 +829,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKV(
     TestFusedMultiTransformerInt8Op
@@ -849,9 +843,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKVFp16(
     TestFusedMultiTransformerInt8Op
@@ -866,9 +859,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpPostLayerNormFp16(
     TestFusedMultiTransformerInt8Op
@@ -882,9 +874,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKVPostLayerNorm(
     TestFusedMultiTransformerInt8Op
@@ -900,9 +891,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpCacheKVPostLayerNormFp16(
     TestFusedMultiTransformerInt8Op
@@ -918,9 +908,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNorm(
     TestFusedMultiTransformerInt8Op
@@ -934,9 +923,8 @@ def config(self):
 
 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
-    or get_cuda_version() < 11020
     or paddle.device.cuda.get_device_capability()[0] < 8,
-    "FusedMultiTransformerInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+    "FusedMultiTransformerInt8 requires CUDA_ARCH >= 8",
 )
 class TestFusedMultiTransformerInt8OpGenCacheKVPostLayerNormFp16(
     TestFusedMultiTransformerInt8Op