
Commit f3e4f2a

mcfi authored and facebook-github-bot committed
Enable arm64 convolution for fbgemm through the reference convolution APIs (#5126)
Summary:

X-link: facebookresearch/FBGEMM#2128

This diff adds convolution support to arm64 fbgemm by reusing the existing reference implementations:

1. Introduced conv_requant_ref, which invokes the reference conv_ref and requantize_u8acc32_ref, and added it in the places where only x86 conv implementations were available.
2. Changed convolution weight packing to either do nothing or call transposeConvWeights.

This diff unblocks fbgemm users' convolution code on Arm64. We plan to follow up with diffs that optimize each kind of convolution (e.g., depthwise, directconv, etc.).

Differential Revision: D86548699
1 parent 1fd545d commit f3e4f2a

16 files changed: +757 -84 lines
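
Before the per-file diffs: the core pattern the summary describes is "run the plain reference convolution into int32 accumulators, then requantize to uint8". Below is a minimal, self-contained C++ sketch of that two-step pattern. It is a toy (1-D convolution, single channel, per-tensor quantization, hypothetical helper names), not fbgemm's actual conv_requant_ref, conv_ref, or requantize_u8acc32_ref, whose real signatures live in src/RefImplementations.h.

// Toy sketch of the "reference conv, then requantize" pattern; names are
// hypothetical and heavily simplified relative to fbgemm's real kernels.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Step 1: plain reference convolution (1-D, stride 1) accumulating into
// int32, analogous to what fbgemm's conv_ref does.
static void conv1d_i32_ref(
    const std::vector<uint8_t>& A, // quantized activations
    int32_t A_zero_point,
    const std::vector<int8_t>& W, // quantized weights
    std::vector<int32_t>& C) { // int32 accumulators
  const std::size_t K = W.size();
  C.assign(A.size() - K + 1, 0);
  for (std::size_t o = 0; o < C.size(); ++o) {
    for (std::size_t k = 0; k < K; ++k) {
      C[o] += (int32_t(A[o + k]) - A_zero_point) * int32_t(W[k]);
    }
  }
}

// Step 2: requantize int32 accumulators back to uint8 with a per-tensor
// scale and zero point, analogous to fbgemm's requantize_u8acc32_ref.
static void requant_u8_ref(
    const std::vector<int32_t>& C_i32,
    float C_multiplier,
    int32_t C_zero_point,
    bool fuse_relu,
    std::vector<uint8_t>& C_u8) {
  C_u8.resize(C_i32.size());
  for (std::size_t i = 0; i < C_i32.size(); ++i) {
    int32_t q = int32_t(std::lrintf(float(C_i32[i]) * C_multiplier)) + C_zero_point;
    if (fuse_relu) q = std::max(q, C_zero_point); // ReLU clamps to the zero point
    C_u8[i] = uint8_t(std::clamp(q, 0, 255));
  }
}

int main() {
  const std::vector<uint8_t> A = {10, 12, 9, 11, 14};
  const std::vector<int8_t> W = {1, -2, 1};
  std::vector<int32_t> acc;
  std::vector<uint8_t> out;
  conv1d_i32_ref(A, /*A_zero_point=*/10, W, acc);
  requant_u8_ref(acc, /*C_multiplier=*/0.5f, /*C_zero_point=*/128,
                 /*fuse_relu=*/false, out);
  for (uint8_t v : out) std::printf("%d ", int(v)); // prints: 126 130 128
  std::printf("\n");
  return 0;
}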

bench/EmbeddingSpMDMNBitBenchmark.cc

Lines changed: 2 additions & 1 deletion
@@ -206,7 +206,8 @@ static int run_benchmark(
         /*output_bit_rate=*/-1);
 #endif

-  vector<OutType>& output = has_weight ? output_slws : output_sls;
+  [[maybe_unused]] vector<OutType>& output =
+      has_weight ? output_slws : output_sls;
   for (bool flush_cache : {false, true}) {
     bool success_ref = false;
     // Reference implementation

include/fbgemm/Fbgemm.h

Lines changed: 0 additions & 8 deletions
@@ -161,11 +161,7 @@ class PackMatrix {
    * @brief Actual packing of a block of the source matrix in pmat buffer.
    */
   void pack(const block_type_t& block) {
-#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
     static_cast<PT*>(this)->pack(block);
-#else
-    throw std::runtime_error("PackMatrix::pack() not implemented for aarch64");
-#endif // __aarch64__
   }

   std::int32_t numRows() const {
@@ -616,11 +612,9 @@ class FBGEMM_API PackWeightsForConv {
     return W_im2col_packed_;
   }

-#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
   std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
     return W_dw_packed_;
   }
-#endif // __aarch64__

   std::shared_ptr<PackedDirectConvMatrix> getPackedWForDirectconv() {
     return W_dc_packed_;
@@ -672,10 +666,8 @@ class FBGEMM_API PackWeightsForConv {
   const conv_param_t<SPATIAL_DIM> conv_param_;
   // Packed weights if we use im2col based convolution implementation
   std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
-#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
   // Packed weights if we use depthwise convolution implementation
   std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
-#endif // __aarch64__
   // Packed weights if we use direct convolution implementation
   std::shared_ptr<PackedDirectConvMatrix> W_dc_packed_;
   // Packed weights if we use groupwise (small channels per group) convolution
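
The packing-side change the summary mentions (point 2) lands in files not excerpted on this page. As a rough illustration, on the reference path weight "packing" can reduce to a layout transpose or a plain copy, since the reference convolution consumes unpacked weights directly. A hedged sketch follows: pack_weights_for_ref_conv is a hypothetical name, and while transposeConvWeights is a real fbgemm helper, the exact wiring in this diff may differ.

// Hedged sketch only, not the actual diff code.
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "fbgemm/Fbgemm.h" // conv_param_t; transposeConvWeights comes via fbgemm/Utils.h

template <int SPATIAL_DIM>
void pack_weights_for_ref_conv( // hypothetical helper
    const fbgemm::conv_param_t<SPATIAL_DIM>& conv_p,
    const std::int8_t* src, // unpacked source weights
    std::int8_t* dst, // "packed" destination, same element count
    std::size_t num_elements,
    bool needs_transpose) {
  if (needs_transpose) {
    // Reorder weights into the layout the reference convolution expects.
    fbgemm::transposeConvWeights<SPATIAL_DIM>(conv_p, src, dst);
  } else {
    // Otherwise "packing" is a no-op copy.
    std::memcpy(dst, src, num_elements * sizeof(std::int8_t));
  }
}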

src/Fbgemm.cc

Lines changed: 5 additions & 0 deletions
@@ -203,6 +203,8 @@ void fbgemmPacked(

 template <int SPATIAL_DIM>
 bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
+#if defined(__x86_64__) || defined(__i386__) || \
+    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
   if constexpr (SPATIAL_DIM == 1)
     return false;

@@ -255,6 +257,9 @@ bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
         return areEqual(std::forward<decltype(PH1)>(PH1), 2);
       })) &&
       !conv_p.transposed;
+#else
+  return false;
+#endif
 }

 template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<1>& conv_p);

src/FbgemmConv.cc

Lines changed: 65 additions & 10 deletions
@@ -12,6 +12,7 @@
 #include <numeric>
 #include <stdexcept> // for logic_error
 #include <vector>
+#include "RefImplementations.h"
 #include "fbgemm/Fbgemm.h"

 namespace fbgemm {
@@ -138,10 +139,6 @@ int fbgemmConv(

   switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
     case optimized_conv_t::depthwise: {
-#if defined(__aarch64__)
-      throw std::runtime_error(
-          "fbgemmConv<processOutputType, SPATIAL_DIM, ACC_T>(): No fallback available for aarch64");
-#else
       // 2D and 3D depthwise fast path
       // std::cout << "Depthwise fast path" << std::endl;
       if constexpr (SPATIAL_DIM == 3) {
@@ -220,7 +217,6 @@ int fbgemmConv(
       throw std::runtime_error(msg);
       }
       break;
-#endif // __aarch64__
     }
     case optimized_conv_t::groupwise: {
       // optimized groupwise convolution
@@ -242,6 +238,8 @@ int fbgemmConv(
       break;
     }
     case optimized_conv_t::pointwise: {
+#if defined(__x86_64__) || defined(__i386__) || \
+    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
       std::vector<int32_t> row_offset_buf(
           PackAWithRowOffset<uint8_t>::rowOffsetBufferSize(blocking_params));
       int image_dim = std::accumulate(
@@ -271,16 +269,42 @@ int fbgemmConv(
           thread_id,
           num_threads,
           blocking_params);
+#else
+      DoNothing<> doNothingObj{};
+      ReQuantizeOutput<
+          processOutputType::RELU_FUSED,
+          processOutputType::QGRANType,
+          typename processOutputType::BIAS_T>
+          reqObj(
+              doNothingObj,
+              outProcess.getCMultiplier(),
+              outProcess.getCZeroPoint(),
+              outProcess.getAZeroPoint(),
+              outProcess.getBZeroPoint(),
+              nullptr, /* row offset buffer */
+              outProcess.getColOffsets(),
+              outProcess.getBias(),
+              conv_p.OC,
+              conv_p.G,
+              outProcess.getActWScale());
+
+      conv_requant_ref(
+          conv_p,
+          activations,
+          packed_weights.getPackedWForPointwise()->getBuf(),
+          false,
+          out,
+          outBuffer,
+          reqObj,
+          thread_id,
+          num_threads);
+#endif
       break;
     }
     case optimized_conv_t::directconv: {
       // specialized direct convolution path
       // std::cout << "Directconv fast path" << std::endl;
       if constexpr (SPATIAL_DIM == 2) {
-#if defined(__aarch64__)
-        throw std::runtime_error(
-            "fbgemmConv<processOutputType, SPATIAL_DIM, ACC_T>(): No fallback available for aarch64");
-#else
         fbgemmDirectConv<SPATIAL_DIM, processOutputType::QGRANType>(
             conv_p,
             // Aint8,
@@ -292,7 +316,6 @@ int fbgemmConv(
             outProcess.getBias(),
             thread_id,
             num_threads);
-#endif
       } else {
         assert(false && "1d/3d direct conv not supported");
       }
@@ -302,6 +325,8 @@ int fbgemmConv(
       break;
     }
     case optimized_conv_t::im2col: {
+#if defined(__x86_64__) || defined(__i386__) || \
+    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
       // All other convolutions go through im2col-based implementation
       // std::cout << "Im2col path" << std::endl;
       std::vector<int32_t> row_offset_buf(
@@ -352,6 +377,36 @@ int fbgemmConv(
           thread_id,
           num_threads,
           blocking_params);
+#else
+      DoNothing<> doNothingObj{};
+      ReQuantizeOutput<
+          processOutputType::RELU_FUSED,
+          processOutputType::QGRANType,
+          typename processOutputType::BIAS_T>
+          reqObj(
+              doNothingObj,
+              outProcess.getCMultiplier(),
+              outProcess.getCZeroPoint(),
+              outProcess.getAZeroPoint(),
+              outProcess.getBZeroPoint(),
+              nullptr, /* row offset buffer */
+              outProcess.getColOffsets(),
+              outProcess.getBias(),
+              conv_p.OC,
+              conv_p.G,
+              outProcess.getActWScale());
+
+      conv_requant_ref(
+          conv_p,
+          activations,
+          packed_weights.getPackedWForIm2col()->getBuf(),
+          false,
+          out,
+          outBuffer,
+          reqObj,
+          thread_id,
+          num_threads);
+#endif
       break;
     }
   } // switch
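
For orientation, this is the shape of the new fallback entry point as it can be inferred from the call sites above. The real declaration lives in src/RefImplementations.h and may differ; in particular, the semantics of the bool argument are not visible in this diff.

// Inferred prototype, hedged; not the actual declaration.
template <int SPATIAL_DIM, typename ReqOutT>
void conv_requant_ref(
    const fbgemm::conv_param_t<SPATIAL_DIM>& conv_p,
    const std::uint8_t* activations, // quantized input
    const std::int8_t* weights, // unpacked (reference-layout) weights
    bool flag, // always `false` at the call sites above; semantics unknown
    std::uint8_t* out, // final quantized output
    std::int32_t* out_buffer, // int32 scratch (nullptr in the depthwise path)
    ReqOutT& req_obj, // ReQuantizeOutput<...> carrying scales/zero points
    int thread_id,
    int num_threads);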

src/FbgemmI8Depthwise3DAvx2.cc

Lines changed: 71 additions & 16 deletions
@@ -980,6 +980,9 @@ void depthwise_3d_same_pad(
     // In C2, batch size 0 is allowed, so we should just early return.
     return;
   }
+
+#if defined(__x86_64__) || defined(__i386__) || \
+    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
   if (fuse_relu) {
     depthwise_3d_same_pad_<true /*FUSE_RELU*/, Q_GRAN>(
         conv_p,
@@ -1011,24 +1014,76 @@ void depthwise_3d_same_pad(
         thread_id,
         num_threads);
   }
+#else
+  DoNothing<> doNothingObj{};
+  if (fuse_relu) {
+    ReQuantizeOutput<true, Q_GRAN, BIAS_TYPE> reqObj(
+        doNothingObj,
+        C_multiplier,
+        C_zero_point,
+        A_zero_point,
+        B_zero_point,
+        nullptr, /* row offset buffer */
+        col_offsets,
+        bias,
+        conv_p.OC,
+        conv_p.G,
+        act_times_w_scale);
+
+    conv_requant_ref(
+        conv_p,
+        A,
+        B.PackedMat(),
+        false,
+        C,
+        nullptr,
+        reqObj,
+        thread_id,
+        num_threads);
+  } else {
+    ReQuantizeOutput<false, Q_GRAN, BIAS_TYPE> reqObj(
+        doNothingObj,
+        C_multiplier,
+        C_zero_point,
+        A_zero_point,
+        B_zero_point,
+        nullptr, /* row offset buffer */
+        col_offsets,
+        bias,
+        conv_p.OC,
+        conv_p.G,
+        act_times_w_scale);
+
+    conv_requant_ref(
+        conv_p,
+        A,
+        B.PackedMat(),
+        false,
+        C,
+        nullptr,
+        reqObj,
+        thread_id,
+        num_threads);
+  }
+#endif
 }

-#define INSTANTIATE_BASE(Q_GRAN, BIAS_TYPE)                   \
-  template FBGEMM_API void                                    \
-  depthwise_3d_same_pad<QuantizationGranularity::Q_GRAN>(     \
-      const conv_param_t<3>& conv_p,                          \
-      int32_t A_zero_point,                                   \
-      const uint8_t* A,                                       \
-      const int32_t* B_zero_point,                            \
-      const PackedDepthWiseConvMatrix& B,                     \
-      const float* C_multiplier,                              \
-      int32_t C_zero_point,                                   \
-      uint8_t* C,                                             \
-      const int32_t* col_offsets,                             \
-      const BIAS_TYPE* bias,                                  \
-      bool fuse_relu,                                         \
-      const float* act_times_w_scale,                         \
-      int thread_id,                                          \
+#define INSTANTIATE_BASE(Q_GRAN, BIAS_TYPE)                           \
+  template FBGEMM_API void                                            \
+  depthwise_3d_same_pad<fbgemm::QuantizationGranularity::Q_GRAN>(     \
+      const fbgemm::conv_param_t<3>& conv_p,                          \
+      int32_t A_zero_point,                                           \
+      const uint8_t* A,                                               \
+      const int32_t* B_zero_point,                                    \
+      const fbgemm::PackedDepthWiseConvMatrix& B,                     \
+      const float* C_multiplier,                                      \
+      int32_t C_zero_point,                                           \
+      uint8_t* C,                                                     \
+      const int32_t* col_offsets,                                     \
+      const BIAS_TYPE* bias,                                          \
+      bool fuse_relu,                                                 \
+      const float* act_times_w_scale,                                 \
+      int thread_id,                                                  \
       int num_threads);

 #define INSTANTIATE_BIAS_T(Q_GRAN) \

src/FbgemmI8DepthwiseAvx2-inl.h

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@
 #include <cmath> // for lrintf and sqrt
 #include <cstdint>
 #include <type_traits> // for is_same
-
+#include "RefImplementations.h"
+#include "fbgemm/Fbgemm.h"
 #if defined(__x86_64__) || defined(__i386__) || \
     (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
