Skip to content

Commit 011ee70

Browse files
committed
Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) [ghstack-poisoned]
2 parents b4f0809 + c99bf42 commit 011ee70

File tree

2 files changed

+44
-38
lines changed

2 files changed

+44
-38
lines changed

build/build_apple_frameworks.sh

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ echo "Exporting headers"
187187

188188
mkdir -p "$HEADERS_PATH"
189189

190-
# Set BUCK2 to the path of the buck2 executable in $OUTPUT/*/buck2-bin/buck2-*
191190
BUCK2=$(find $SOURCE_ROOT_DIR -type f -path '*/buck2-bin/buck2-*' | head -n 1)
192191
if [[ -z "$BUCK2" ]]; then
193192
echo "Could not find buck2 executable in any buck2-bin directory under $OUTPUT"
@@ -208,11 +207,11 @@ check_command "$BUCK2"
208207
# So, just patch our generated framework to do that.
209208
sed -i '' '1i\
210209
#define C10_USING_CUSTOM_GENERATED_MACROS
211-
' $SOURCE_ROOT_DIR/runtime/core/portable_type/c10/macros/Macros.h
210+
' $HEADERS_PATH/executorch/runtime/core/portable_type/c10/macros/Macros.h
212211
sed -i '' '1i\
213212
#define C10_USING_CUSTOM_GENERATED_MACROS
214-
' $SOURCE_ROOT_DIR/runtime/core/portable_type/c10/macros/Export.h
215-
cp -r $SOURCE_ROOT_DIR/runtime/core/portable_type/c10 "$HEADERS_PATH/"
213+
' $HEADERS_PATH/executorch/runtime/core/portable_type/c10/macros/Export.h
214+
ln -s $HEADERS_PATH/executorch/runtime/core/portable_type/c10 "$HEADERS_PATH/"
216215

217216

218217
cp "$SOURCE_ROOT_DIR/extension/apple/ExecuTorch/Exported/"*.h "$HEADERS_PATH/executorch"

runtime/core/portable_type/c10/util/BFloat16-math.h

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ C10_CLANG_DIAGNOSTIC_PUSH()
88
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
99
#endif
1010

11-
namespace std {
12-
11+
namespace c10 {
1312
template <typename T>
1413
struct is_reduced_floating_point
1514
: std::integral_constant<
@@ -19,193 +18,201 @@ struct is_reduced_floating_point
1918
template <typename T>
2019
constexpr bool is_reduced_floating_point_v =
2120
is_reduced_floating_point<T>::value;
21+
} // namespace c10
22+
23+
namespace std {
24+
25+
#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
26+
using c10::is_reduced_floating_point;
27+
using c10::is_reduced_floating_point_v;
28+
#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
2229

2330
template <
2431
typename T,
25-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
32+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
2633
inline T acos(T a) {
2734
return std::acos(float(a));
2835
}
2936
template <
3037
typename T,
31-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
38+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
3239
inline T asin(T a) {
3340
return std::asin(float(a));
3441
}
3542
template <
3643
typename T,
37-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
44+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
3845
inline T atan(T a) {
3946
return std::atan(float(a));
4047
}
4148
template <
4249
typename T,
43-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
50+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
4451
inline T atanh(T a) {
4552
return std::atanh(float(a));
4653
}
4754
template <
4855
typename T,
49-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
56+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
5057
inline T erf(T a) {
5158
return std::erf(float(a));
5259
}
5360
template <
5461
typename T,
55-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
62+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
5663
inline T erfc(T a) {
5764
return std::erfc(float(a));
5865
}
5966
template <
6067
typename T,
61-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
68+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
6269
inline T exp(T a) {
6370
return std::exp(float(a));
6471
}
6572
template <
6673
typename T,
67-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
74+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
6875
inline T expm1(T a) {
6976
return std::expm1(float(a));
7077
}
7178
template <
7279
typename T,
73-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
80+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
7481
inline bool isfinite(T a) {
7582
return std::isfinite(float(a));
7683
}
7784
template <
7885
typename T,
79-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
86+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
8087
inline T log(T a) {
8188
return std::log(float(a));
8289
}
8390
template <
8491
typename T,
85-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
92+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
8693
inline T log10(T a) {
8794
return std::log10(float(a));
8895
}
8996
template <
9097
typename T,
91-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
98+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
9299
inline T log1p(T a) {
93100
return std::log1p(float(a));
94101
}
95102
template <
96103
typename T,
97-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
104+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
98105
inline T log2(T a) {
99106
return std::log2(float(a));
100107
}
101108
template <
102109
typename T,
103-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
110+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
104111
inline T ceil(T a) {
105112
return std::ceil(float(a));
106113
}
107114
template <
108115
typename T,
109-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
116+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
110117
inline T cos(T a) {
111118
return std::cos(float(a));
112119
}
113120
template <
114121
typename T,
115-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
122+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
116123
inline T floor(T a) {
117124
return std::floor(float(a));
118125
}
119126
template <
120127
typename T,
121-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
128+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
122129
inline T nearbyint(T a) {
123130
return std::nearbyint(float(a));
124131
}
125132
template <
126133
typename T,
127-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
134+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
128135
inline T sin(T a) {
129136
return std::sin(float(a));
130137
}
131138
template <
132139
typename T,
133-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
140+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
134141
inline T tan(T a) {
135142
return std::tan(float(a));
136143
}
137144
template <
138145
typename T,
139-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
146+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
140147
inline T sinh(T a) {
141148
return std::sinh(float(a));
142149
}
143150
template <
144151
typename T,
145-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
152+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
146153
inline T cosh(T a) {
147154
return std::cosh(float(a));
148155
}
149156
template <
150157
typename T,
151-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
158+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
152159
inline T tanh(T a) {
153160
return std::tanh(float(a));
154161
}
155162
template <
156163
typename T,
157-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
164+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
158165
inline T trunc(T a) {
159166
return std::trunc(float(a));
160167
}
161168
template <
162169
typename T,
163-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
170+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
164171
inline T lgamma(T a) {
165172
return std::lgamma(float(a));
166173
}
167174
template <
168175
typename T,
169-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
176+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
170177
inline T sqrt(T a) {
171178
return std::sqrt(float(a));
172179
}
173180
template <
174181
typename T,
175-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
182+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
176183
inline T rsqrt(T a) {
177184
return 1.0 / std::sqrt(float(a));
178185
}
179186
template <
180187
typename T,
181-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
188+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
182189
inline T abs(T a) {
183190
return std::abs(float(a));
184191
}
185192
#if defined(_MSC_VER) && defined(__CUDACC__)
186193
template <
187194
typename T,
188-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
195+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
189196
inline T pow(T a, double b) {
190197
return std::pow(float(a), float(b));
191198
}
192199
#else
193200
template <
194201
typename T,
195-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
202+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
196203
inline T pow(T a, double b) {
197204
return std::pow(float(a), b);
198205
}
199206
#endif
200207
template <
201208
typename T,
202-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
209+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
203210
inline T pow(T a, T b) {
204211
return std::pow(float(a), float(b));
205212
}
206213
template <
207214
typename T,
208-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
215+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
209216
inline T fmod(T a, T b) {
210217
return std::fmod(float(a), float(b));
211218
}
@@ -238,7 +245,7 @@ inline T fmod(T a, T b) {
238245
*/
239246
template <
240247
typename T,
241-
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
248+
typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
242249
C10_HOST_DEVICE inline T nextafter(T from, T to) {
243250
// Reference:
244251
// https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c

0 commit comments

Comments
 (0)