1212import torch
1313import torch .fx
1414from executorch .backends .arm .operators .node_visitor import NodeVisitor
15- from executorch .backends .arm .tosa_mapping import TosaArg
15+ from executorch .backends .arm .tosa_mapping import map_dtype , TosaArg
16+ from executorch .backends .arm .tosa_quant_utils import (
17+ dq_op ,
18+ get_quantized_node_output_dtype ,
19+ is_node_quantized ,
20+ )
1621from executorch .backends .arm .tosa_specification import TosaSpecification
1722from executorch .backends .arm .tosa_utils import getNodeArgs , tosa_shape
1823from torch .export .exported_program import ExportedProgram
@@ -30,8 +35,15 @@ def process_call_function(
3035 # Convert output (this node itself)
3136 output = TosaArg (node )
3237
38+ is_dq_node = node .target == dq_op
39+ if is_dq_node :
40+ output_dtype = ts .DType .INT8
41+ else :
42+ output_dtype = output .dtype
3343 tosa_graph .currRegion .currBasicBlock .addTensor (
34- output .name , tosa_shape (output .shape , output .dim_order ), output .dtype
44+ output .name ,
45+ tosa_shape (output .shape , output .dim_order ),
46+ output_dtype ,
3547 )
3648
3749 # Visiting each Node
@@ -67,7 +79,11 @@ def process_inputs(
6779 tensor = ts .TosaSerializerTensor (
6880 inputs [0 ].name ,
6981 tosa_shape (input_shape , input_dim_order ),
70- inputs [0 ].dtype ,
82+ (
83+ map_dtype (get_quantized_node_output_dtype (node ))
84+ if is_node_quantized (node )
85+ else inputs [0 ].dtype
86+ ),
7187 data = None ,
7288 placeholderFilename = inputs [0 ].name + ".npy" ,
7389 )
0 commit comments