@@ -2110,6 +2110,102 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)


+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class ReplaceGeluWithApproximateGeluPass(ExportPass):
+    """
+    Replace the gelu op with an approximate gelu op. The approximate gelu op
+    is more efficient on DSP backends.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.gelu.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi)) as
+        # 0.5 * x * (1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3)))
+
+        # Get 0.5 * x
+        half = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.5),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x
+        scaled = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (args[0], 0.044715),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of
+        # pow.Tensor because it is much more efficient on DSP backends)
+        scaled_square = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.044715 * x^3
+        scaled_cubed = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (scaled_square, args[0]),
+            {},
+            meta,
+        )
+
+        # Get x + 0.044715 * x^3
+        inner_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (scaled_cubed, args[0]),
+            {},
+            meta,
+        )
+
+        # Get 0.7978845608028654 * (x + 0.044715 * x^3)
+        scaled_sum = super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (inner_sum, 0.7978845608028654),
+            {},
+            meta,
+        )
+
+        # Get torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        tanh = super().call_operator(
+            exir_ops.edge.aten.tanh.default,
+            (scaled_sum,),
+            {},
+            meta,
+        )
+
+        # Get 1 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x^3))
+        # TODO(): Check why this is not working properly with integer values
+        # (e.g. 1 instead of 1.)
+        outer_sum = super().call_operator(
+            exir_ops.edge.aten.add.Tensor,
+            (tanh, 1.0),
+            {},
+            meta,
+        )
+
+        # Return the final result
+        return super().call_operator(
+            exir_ops.edge.aten.mul.Tensor,
+            (half, outer_sum),
+            {},
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2149,4 +2245,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
+        # ReplaceGeluWithApproximateGeluPass,
     ]
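
The decomposition in the new pass is the standard tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). As a quick sanity check (a minimal sketch, not part of this commit; the sample input and tolerance are assumptions), it can be compared against PyTorch's built-in tanh-approximate gelu, which uses the same formula:

```python
import math

import torch

x = torch.randn(1024)

# Manual tanh approximation, mirroring the op decomposition in the pass above.
approx = 0.5 * x * (
    1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
)

# PyTorch exposes the same approximation via approximate="tanh".
assert torch.allclose(
    approx, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6
)

# It is close to, but not identical with, the exact erf-based gelu.
print((approx - torch.nn.functional.gelu(x)).abs().max())
```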