
 @register_node_visitor
 class Pad(NodeVisitor):
-    target = ["aten.constant_pad_nd.default"]
+    target = [
+        "aten.constant_pad_nd.default",
+        "aten.pad.default",  # handles reflect/replicate modes
+        "aten.reflection_pad2d.default",
+        "aten.replication_pad2d.default",
+    ]

     def __init__(self, *args) -> None:
         super().__init__(*args)
@@ -28,6 +33,8 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+        # ---- Input tensor ----
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         pad_inp_tensor_wrapper = self.define_tensor(
@@ -39,6 +46,7 @@ def define_node(
         )
         pad_input_tensors = [pad_inp_tensor_wrapper]

+        # ---- Output tensor ----
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -49,21 +57,43 @@ def define_node(
         )
         pad_output_tensors = [output_tensor_wrapper]

+        # ---- Pad amount handling ----
+        pad_list = cast(List[int], node.args[1])
         pad_amount_shape = [input_tensor.dim(), 2]
-        # pytorch padding start from the last index
-        pad_amount = np.reshape(cast(List[int], node.args[1]), (-1, 2))[::-1].astype(
-            np.uint32
-        )
-        # fulfill the pad amount for each idex of tensor
+
+        # PyTorch pad order: [last_dim, ..., first_dim]
+        pad_amount = np.reshape(pad_list, (-1, 2))[::-1].astype(np.uint32)
+
+        # Expand to full rank if needed
         if zero_amounts := pad_amount_shape[0] - pad_amount.shape[0]:
             pad_amount = np.concatenate(
                 (np.array([(0, 0)] * zero_amounts), pad_amount)
             ).astype(np.uint32)

+        # Apply axis reordering if necessary
         if QCOM_AXIS_ORDER in node.meta:
             pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])]
-        pad_amount_val = node.args[2]

+
+        # ---- Determine mode ----
+        if len(node.args) >= 3 and isinstance(node.args[2], str):
+            mode = node.args[2]
+        elif "reflection" in str(node.target):
+            mode = "reflect"
+        elif "replication" in str(node.target):
+            mode = "replicate"
+        else:
+            mode = "constant"
+
+        scheme_map = {
+            "constant": OpPad.Scheme.CONSTANT,
+            "reflect": OpPad.Scheme.MIRROR_REFLECT,
+            "replicate": OpPad.Scheme.EDGE,
+        }
+
+        if mode not in scheme_map:
+            raise ValueError(f"[QNN][Pad] Unsupported pad mode: {mode}")
+
+        # ---- Create QNN op ----
         pad_op = PyQnnWrapper.PyQnnOpWrapper(
             node.name,
             QNN_OP_PACKAGE_NAME_QTI_AISW,
@@ -72,19 +102,29 @@ def define_node(
         pad_op.AddInputTensors(pad_input_tensors)
         pad_op.AddOutputTensors(pad_output_tensors)

-        # For now, we only support constant (0) padding due to torch implementation
+        # scheme param
         pad_op.AddScalarParam(
             OpPad.param_scheme,
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            {QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
+            {QCOM_DATA: np.uint32(scheme_map[mode])},
         )

-        pad_op.AddScalarParam(
-            OpPad.param_pad_constant_value,
-            QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
-            {QCOM_DATA: pad_amount_val},
-        )
+        # pad_constant_value param (only for constant mode)
+        if mode == "constant":
+            # torch.constant_pad_nd takes an optional pad value, default = 0.0
+            pad_value = node.kwargs.get("value", None)
+            if pad_value is None and len(node.args) > 2 and not isinstance(node.args[2], str):
+                pad_value = node.args[2]
+            if pad_value is None:
+                pad_value = 0.0
+
+            pad_op.AddScalarParam(
+                OpPad.param_pad_constant_value,
+                QNN_TENSOR_TYPE_MAP[type(pad_value)],
+                {QCOM_DATA: pad_value},
+            )

+        # pad_amount tensor param
         pad_op.AddTensorParam(
             OpPad.param_pad_amount,
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
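
For reference, here is a small standalone sketch of the pad_amount handling in the hunk above: PyTorch lists pad pairs starting from the last dimension, so the pairs are reversed into first-to-last dimension order and any leading dimensions not covered by the pad list are filled with (0, 0). The helper name build_pad_amount and the example values are illustrative only and are not part of this change.

import numpy as np
from typing import List

def build_pad_amount(pad_list: List[int], rank: int) -> np.ndarray:
    # Group the flat pad list into (before, after) pairs, last dim first,
    # then reverse so rows run from the first dimension to the last.
    pad_amount = np.reshape(pad_list, (-1, 2))[::-1].astype(np.uint32)
    # Prepend (0, 0) rows for leading dimensions not mentioned in pad_list.
    missing = rank - pad_amount.shape[0]
    if missing:
        pad_amount = np.concatenate(
            (np.array([(0, 0)] * missing), pad_amount)
        ).astype(np.uint32)
    return pad_amount

# Example: pad H and W of an NCHW tensor by 1 on each side.
# PyTorch pad list: [w_left, w_right, h_top, h_bottom]
# -> [[0, 0], [0, 0], [1, 1], [1, 1]]  (per-dim [before, after], N first)
print(build_pad_amount([1, 1, 1, 1], rank=4))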