@@ -40,6 +40,11 @@
 )
 
 from executorch.backends.arm.vgf_partitioner import VgfPartitioner
+
+# To use Cortex-M backend
+from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
+    ReplaceQuantNodesPass,
+)
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
 
@@ -59,6 +64,7 @@
 from ..models import MODEL_NAME_TO_MODEL
 from ..models.model_factory import EagerModelFactory
 
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.WARNING, format=FORMAT)
 
@@ -216,6 +222,54 @@ def forward(self, x, y):
     can_delegate = True
 
 
+class QuantAddTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a):
+        return a + a
+
+    example_input = (torch.rand([13, 3], dtype=torch.float32),)  # a - normal values
+    can_delegate = True  # when quantized
+
+
+class QuantAddTest2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b):
+        p = a + a
+        q = b + b
+        r = p + q
+        return p, q, r
+
+    example_input = (
+        torch.randn([13, 7, 3], dtype=torch.float32),
+        torch.randn([13, 7, 3], dtype=torch.float32),
+    )
+    can_delegate = True  # when quantized
+
+
+class QuantOpTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, w, x, y, z):
+        o1 = w - x
+        o2 = o1 + y
+        o3 = o2 * z
+        return o1, o2, o3
+
+    example_input = (
+        torch.randn([3, 1, 2], dtype=torch.float32),  # w - normal values
+        torch.randn([3, 5, 2], dtype=torch.float32),  # x - normal values
+        torch.randn([3, 5, 1], dtype=torch.float32)
+        * -0.000001,  # y - small negative values, needs calibration for tests
+        torch.randn([3, 5, 2], dtype=torch.float32) * 1000,  # z - large values
+    )
+    can_delegate = True  # when quantized
+
+
 class SoftmaxModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -241,6 +295,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     "add": AddModule,
     "add2": AddModule2,
     "add3": AddModule3,
+    "qadd": QuantAddTest,
+    "qadd2": QuantAddTest2,
+    "qops": QuantOpTest,
     "softmax": SoftmaxModule,
     "MultipleOutputsModule": MultipleOutputsModule,
 }
@@ -255,6 +312,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         torch.randn(32, 5),
         torch.randn(32, 5),
     ),
+    "qadd": (torch.randn(32, 2, 1),),
+    "qadd2": (
+        torch.randn(32, 2, 1),
+        torch.randn(32, 2, 1),
+    ),
+    "qops": (
+        torch.randn(32, 2, 1),
+        torch.randn(32, 2, 1),
+        torch.randn(32, 2, 1) * -0.000001,
+        torch.randn(32, 2, 1) * 1000,
+    ),
     "softmax": (torch.randn(32, 2, 2),),
 }
 
@@ -656,6 +724,7 @@ def to_edge_TOSA_delegate(
             _check_ir_validity=False,
         ),
     )
+
     return model_int8, edge
 
 
@@ -681,9 +750,18 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
             _check_ir_validity=False,
         ),
     )
+
     return model_int8, edge
 
 
+def transform_for_cortex_m_backend(edge):
+    # Make sure we are using the optimized Cortex-M backend operators.
+    # NB: if the ops that are expected to be replaced cannot be found and replaced,
+    # bad things will happen at runtime, like "missing operator" errors!
+    edge = edge.transform([ReplaceQuantNodesPass()])
+    return edge
+
+
 if __name__ == "__main__":  # noqa: C901
     args = get_args()
 
@@ -715,6 +793,9 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
             exported_program, args, model, example_inputs
         )
 
+    # Transform so we can use ops from the Cortex-M backend
+    edge = transform_for_cortex_m_backend(edge)
+
     dump_delegation_info(edge, args.intermediates)
 
     try:
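
For context, a minimal sketch of where the new pass sits in the export flow. It reuses names defined in this file and shown in the diff above (to_edge_no_delegate, transform_for_cortex_m_backend, dump_delegation_info, args, exported_program, model, example_inputs); the final to_executorch() serialization step is assumed from the usual ExecuTorch flow and is not part of this change:

# Sketch only: how transform_for_cortex_m_backend() fits into the export flow.
# Every name except to_executorch() appears in the diff above; to_executorch()
# is the standard ExecuTorch serialization step, assumed here for completeness.
model_int8, edge = to_edge_no_delegate(exported_program, args, model, example_inputs)

# Swap generic quantize/dequantize nodes for the Cortex-M optimized operators;
# skipping this step can surface as "missing operator" errors at runtime.
edge = transform_for_cortex_m_backend(edge)

# Inspect what was (and was not) delegated before serializing.
dump_delegation_info(edge, args.intermediates)

# Serialize the transformed edge program to an ExecuTorch (.pte) program.
executorch_program = edge.to_executorch()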