@@ -272,9 +272,24 @@ def test_qnn_backend_cos(self):
272272 self .lower_module_and_test_output (module , sample_input )
273273
274274 def test_qnn_backend_cumsum (self ):
275- module = CumSum () # noqa: F405
276- sample_input = (torch .randn (4 ),)
277- self .lower_module_and_test_output (module , sample_input )
275+ sample_input = ()
276+ test_comb = [
277+ {
278+ QCOM_MODULE : [CumSum ()], # noqa: F405
279+ QCOM_SAMPLE_INPUTS : [
280+ (torch .randn (4 ),),
281+ (torch .randint (0 , 10 , size = (4 ,)),),
282+ ],
283+ }
284+ ]
285+
286+ index = 0
287+ for comb in test_comb :
288+ for module in comb [QCOM_MODULE ]:
289+ for sample_input in comb [QCOM_SAMPLE_INPUTS ]:
290+ with self .subTest (i = index ):
291+ self .lower_module_and_test_output (module , sample_input )
292+ index += 1
278293
279294 def test_qnn_backend_einsum_outer_product (self ):
280295 module = EinsumOuterProduct () # noqa: F405
@@ -311,6 +326,12 @@ def test_qnn_backend_element_wise_add(self):
311326 QCOM_MODULE : [AddConstantFloat ()], # noqa: F405
312327 QCOM_SAMPLE_INPUTS : [(torch .randn (2 , 5 , 1 , 3 ),)],
313328 },
329+ {
330+ QCOM_MODULE : [
331+ AddConstantLong (), # noqa: F405
332+ ],
333+ QCOM_SAMPLE_INPUTS : [(torch .randint (0 , 10 , size = (2 , 3 )),)],
334+ },
314335 ]
315336
316337 index = 0
@@ -4526,6 +4547,40 @@ def test_retinanet(self):
45264547 else :
45274548 self .assertGreaterEqual (msg ["mAP" ], 0.6 )
45284549
4550+ def test_roberta (self ):
4551+ if not self .required_envs ([self .sentence_dataset ]):
4552+ self .skipTest ("missing required envs" )
4553+ cmds = [
4554+ "python" ,
4555+ f"{ self .executorch_root } /examples/qualcomm/oss_scripts/roberta.py" ,
4556+ "--dataset" ,
4557+ self .sentence_dataset ,
4558+ "--artifact" ,
4559+ self .artifact_dir ,
4560+ "--build_folder" ,
4561+ self .build_folder ,
4562+ "--device" ,
4563+ self .device ,
4564+ "--model" ,
4565+ self .model ,
4566+ "--ip" ,
4567+ self .ip ,
4568+ "--port" ,
4569+ str (self .port ),
4570+ ]
4571+ if self .host :
4572+ cmds .extend (["--host" , self .host ])
4573+
4574+ p = subprocess .Popen (cmds , stdout = subprocess .DEVNULL )
4575+ with Listener ((self .ip , self .port )) as listener :
4576+ conn = listener .accept ()
4577+ p .communicate ()
4578+ msg = json .loads (conn .recv ())
4579+ if "Error" in msg :
4580+ self .fail (msg ["Error" ])
4581+ else :
4582+ self .assertGreaterEqual (msg ["accuracy" ], 0.5 )
4583+
45294584 def test_squeezenet (self ):
45304585 if not self .required_envs ([self .image_dataset ]):
45314586 self .skipTest ("missing required envs" )
@@ -5344,6 +5399,11 @@ def setup_environment():
53445399 help = "Location for imagenet dataset" ,
53455400 type = str ,
53465401 )
5402+ parser .add_argument (
5403+ "--sentence_dataset" ,
5404+ help = "Location for sentence dataset" ,
5405+ type = str ,
5406+ )
53475407 parser .add_argument (
53485408 "-p" ,
53495409 "--pretrained_weight" ,
@@ -5402,6 +5462,7 @@ def setup_environment():
54025462 TestQNN .executorch_root = args .executorch_root
54035463 TestQNN .artifact_dir = args .artifact_dir
54045464 TestQNN .image_dataset = args .image_dataset
5465+ TestQNN .sentence_dataset = args .sentence_dataset
54055466 TestQNN .pretrained_weight = args .pretrained_weight
54065467 TestQNN .model_name = args .model_name
54075468 TestQNN .online_prepare = args .online_prepare
0 commit comments