@@ -1008,7 +1008,11 @@ def test_qnn_backend_rsqrt(self):
1008
1008
self .lower_module_and_test_output (module , sample_input )
1009
1009
1010
1010
def test_qnn_backend_sdpa (self ):
1011
- module = ScaledDotProductAttention () # noqa: F405
1011
+ modules = [
1012
+ ScaledDotProductAttention (), # noqa: F405
1013
+ ScaledDotProductAttention (scale = 0.5 ), # noqa: F405
1014
+ ScaledDotProductAttention (scale = 1.0 ), # noqa: F405
1015
+ ]
1012
1016
mask = torch .tril (torch .randn (1 , 1 , 100 , 100 ))
1013
1017
mask [mask == 0 ] = float ("-inf" )
1014
1018
sample_input = (
@@ -1017,7 +1021,9 @@ def test_qnn_backend_sdpa(self):
1017
1021
torch .randn (1 , 4 , 100 , 64 ),
1018
1022
mask ,
1019
1023
)
1020
- self .lower_module_and_test_output (module , sample_input )
1024
+ for i , module in enumerate (modules ):
1025
+ with self .subTest (i = i ):
1026
+ self .lower_module_and_test_output (module , sample_input )
1021
1027
1022
1028
def test_qnn_backend_sigmoid (self ):
1023
1029
module = Sigmoid () # noqa: F405
@@ -2414,7 +2420,11 @@ def test_qnn_backend_rsqrt(self):
2414
2420
self .lower_module_and_test_output (module , sample_input )
2415
2421
2416
2422
def test_qnn_backend_sdpa (self ):
2417
- module = ScaledDotProductAttention () # noqa: F405
2423
+ modules = [
2424
+ ScaledDotProductAttention (), # noqa: F405
2425
+ ScaledDotProductAttention (scale = 0.5 ), # noqa: F405
2426
+ ScaledDotProductAttention (scale = 1.0 ), # noqa: F405
2427
+ ]
2418
2428
mask = torch .tril (torch .randn (1 , 1 , 100 , 100 ))
2419
2429
mask [mask == 0 ] = torch .finfo (torch .float32 ).min
2420
2430
sample_input = (
@@ -2423,8 +2433,12 @@ def test_qnn_backend_sdpa(self):
2423
2433
torch .randn (1 , 4 , 100 , 64 ),
2424
2434
mask ,
2425
2435
)
2426
- module = self .get_qdq_module (module , sample_input )
2427
- self .lower_module_and_test_output (module , sample_input )
2436
+ for i , module in enumerate (modules ):
2437
+ with self .subTest (i = i ):
2438
+ module = self .get_qdq_module (
2439
+ module , sample_input , quant_dtype = QuantDtype .use_16a8w
2440
+ )
2441
+ self .lower_module_and_test_output (module , sample_input )
2428
2442
2429
2443
def test_qnn_backend_select_copy (self ):
2430
2444
module = SelectCopy () # noqa: F405
@@ -4951,13 +4965,14 @@ def test_gMLP(self):
4951
4965
self .assertGreaterEqual (msg ["top_1" ], 60 )
4952
4966
self .assertGreaterEqual (msg ["top_5" ], 85 )
4953
4967
4954
- def test_mobilevit_v1 (self ):
4968
+ @unittest .skip ("Only outputs good accuracy in QNN 2.29" )
4969
+ def test_mobilevit_v2 (self ):
4955
4970
if not self .required_envs ([self .image_dataset ]):
4956
4971
self .skipTest ("missing required envs" )
4957
4972
4958
4973
cmds = [
4959
4974
"python" ,
4960
- f"{ self .executorch_root } /examples/qualcomm/oss_scripts/mobilevit_v1 .py"
4975
+ f"{ self .executorch_root } /examples/qualcomm/oss_scripts/mobilevit_v2 .py" ,
4961
4976
"--dataset" ,
4962
4977
self .image_dataset ,
4963
4978
"--artifact" ,
@@ -4975,6 +4990,8 @@ def test_mobilevit_v1(self):
4975
4990
]
4976
4991
if self .host :
4977
4992
cmds .extend (["--host" , self .host ])
4993
+ if self .shared_buffer :
4994
+ cmds .extend (["--shared_buffer" ])
4978
4995
4979
4996
p = subprocess .Popen (cmds , stdout = subprocess .DEVNULL )
4980
4997
with Listener ((self .ip , self .port )) as listener :
@@ -4984,17 +5001,16 @@ def test_mobilevit_v1(self):
4984
5001
if "Error" in msg :
4985
5002
self .fail (msg ["Error" ])
4986
5003
else :
4987
- self .assertGreaterEqual (msg ["top_1" ], 70 )
5004
+ self .assertGreaterEqual (msg ["top_1" ], 50 )
4988
5005
self .assertGreaterEqual (msg ["top_5" ], 85 )
4989
5006
4990
- @unittest .skip ("Only outputs good accuracy in QNN 2.29" )
4991
- def test_mobilevit_v2 (self ):
5007
+ def test_mobilevit1 (self ):
4992
5008
if not self .required_envs ([self .image_dataset ]):
4993
5009
self .skipTest ("missing required envs" )
4994
5010
4995
5011
cmds = [
4996
5012
"python" ,
4997
- f"{ self .executorch_root } /examples/qualcomm/oss_scripts/mobilevit_v2 .py" ,
5013
+ f"{ self .executorch_root } /examples/qualcomm/oss_scripts/mobilevit1 .py" ,
4998
5014
"--dataset" ,
4999
5015
self .image_dataset ,
5000
5016
"--artifact" ,
@@ -5012,8 +5028,6 @@ def test_mobilevit_v2(self):
5012
5028
]
5013
5029
if self .host :
5014
5030
cmds .extend (["--host" , self .host ])
5015
- if self .shared_buffer :
5016
- cmds .extend (["--shared_buffer" ])
5017
5031
5018
5032
p = subprocess .Popen (cmds , stdout = subprocess .DEVNULL )
5019
5033
with Listener ((self .ip , self .port )) as listener :
@@ -5023,7 +5037,7 @@ def test_mobilevit_v2(self):
5023
5037
if "Error" in msg :
5024
5038
self .fail (msg ["Error" ])
5025
5039
else :
5026
- self .assertGreaterEqual (msg ["top_1" ], 50 )
5040
+ self .assertGreaterEqual (msg ["top_1" ], 70 )
5027
5041
self .assertGreaterEqual (msg ["top_5" ], 85 )
5028
5042
5029
5043
def test_pvt (self ):
@@ -5033,7 +5047,11 @@ def test_pvt(self):
5033
5047
cmds = [
5034
5048
"python" ,
5035
5049
f"{ self .executorch_root } /examples/qualcomm/oss_scripts/pvt.py" ,
5050
+ "--dataset" ,
5036
5051
self .image_dataset ,
5052
+ "--artifact" ,
5053
+ self .artifact_dir ,
5054
+ "--build_folder" ,
5037
5055
self .build_folder ,
5038
5056
"--device" ,
5039
5057
self .device ,
0 commit comments