 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
-from torch._subclasses import FakeTensor
-from torch.utils._pytree import tree_map_only
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -76,7 +74,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
-    Replace the pt2 quantization ops with custom cadence quantization ops.
+    Replace the pt2 quantization ops with jarvis quantization ops.
+    We do not link kernels to the PT2 quantization ops, so we need to
+    replace them with jarvis ops at all optimization levels.
     """
 
     def call_operator(
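
For context, the swap performed by this pass (and by its dequant twin below) is a one-to-one retarget inside call_operator: the node's args, kwargs, and metadata are preserved and only the op changes. A minimal sketch of the pattern, with the exact edge targets (quantized_decomposed.quantize_per_tensor and cadence.quantize_per_tensor) assumed here for illustration rather than taken from this diff:

    # Sketch of the op-swap pattern shared by the quant and dequant passes.
    # The target names below are assumptions for illustration.
    def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
        if op != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
            return super().call_operator(op, args, kwargs, meta)
        # Same args, kwargs, and metadata; only the op target changes.
        return super().call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default, args, kwargs, meta
        )
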
@@ -100,7 +100,9 @@ def call_operator(
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     """
-    Replace the pt2 dequantization ops with custom cadence dequantization ops.
+    Replace the pt2 dequantization ops with jarvis dequantization ops.
+    We do not link kernels to the PT2 dequantization ops, so we need to
+    replace them with jarvis ops at all optimization levels.
     """
 
     def call_operator(
@@ -188,49 +190,44 @@ def call_operator(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveZeroSizedCatArgsPass(ExportPass):  # is this the latest?
+class RemoveZeroSizedCatArgsPass(ExportPass):
     def call_operator(
         self,
         op,  # pyre-ignore
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
         if op != exir_ops.edge.aten.cat.default:
             return super().call_operator(op, args, kwargs, meta)
 
         # Remove any zero-sized tensor arg to form a new args list.
-        new_args = []
-        for arg in args[0]:
-            arg_tensor = arg.to_tensor() if isinstance(arg, ProxyValue) else arg
-            if arg_tensor.numel() > 0:
-                new_args.append(arg)
+        cat_inputs: list[ProxyValue] = []
+        for arg in cast(Sequence[ProxyValue], args[0]):
+            if arg.to_tensor().numel() > 0:
+                cat_inputs.append(arg)
 
         # If all the tensors were empty, we just return an empty tensor with
         # the right shape.
-        if not new_args:
-            args_data, kwargs_data = tree_map_only(
-                ProxyValue, lambda x: x.data, (args, kwargs)
+        if not cat_inputs:
+            empty_shape = meta["val"].shape
+            dtype = meta["val"].dtype
+            return super().call_operator(
+                exir_ops.edge.aten.full.default,
+                (tuple(empty_shape), 0),
+                {"dtype": dtype},
+                meta,
             )
-            result = op(*args_data, **kwargs_data)
-            # When tracing with PT2, the FakeTensor mode requires the constant
-            # argument to be set to itself.
-            # TODO(matthiascremon): confirm this is the best way to do this.
-            if isinstance(result, FakeTensor):
-                result.constant = result
-            # pyre-ignore[7]: Incompatible return type.
-            return torch.empty_like(result)
-
-        # If there was only one tensor in the new_args list,
+
+        # If there was only one tensor in the cat_inputs list,
         # we can safely erase this cat op.
-        if len(new_args) == 1:
-            return new_args[0]
+        if len(cat_inputs) == 1:
+            return cat_inputs[0]
 
-        # Otherwise, we replace args[0] with new_args.
-        init_args = list(args)
-        init_args[0] = new_args
-        args = tuple(args)
-        return super().call_operator(op, args, kwargs, meta)
+        # Otherwise, we replace args[0] with cat_inputs.
+        new_args = list(args)
+        new_args[0] = cat_inputs
+        return super().call_operator(op, tuple(new_args), kwargs, meta)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
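
As a worked example of the empty-cat branch above: meta["val"] holds the FakeTensor the tracer recorded for the original cat output, so once every input has numel() == 0, its shape and dtype fully describe the result, and a single full op can stand in for the cat. A standalone sketch of that equivalence, with hypothetical shapes:

    import torch

    # Hypothetical inputs: every cat operand is zero-sized.
    x = torch.empty(0, 4)
    y = torch.empty(0, 4)

    # What the graph computed before this change.
    cat_out = torch.cat([x, y])
    # What the pass now emits instead: one full op built from the shape and
    # dtype that tracing recorded in meta["val"].
    full_out = torch.full(tuple(cat_out.shape), 0, dtype=cat_out.dtype)

    assert full_out.shape == cat_out.shape and full_out.dtype == cat_out.dtype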