
Commit 96d6de4

feat: only use GEMM stateless exec for fixed-format
Make GemmCommon::execute_stateless() a no-op in release mode and an
assert(0) in debug mode for non-fixed-format kernels. This reflects the
reality that stateless, thread-safe execution is only valid for
fixed-format kernels.

Change-Id: I1ba1956e6a27a05fc1bb0c95b62996ca1c4833a6
Signed-off-by: Siddhartha Menon <[email protected]>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/14375
Tested-by: Arm Jenkins <[email protected]>
Benchmark: Arm Jenkins <[email protected]>
Reviewed-by: Gunes Bayir <[email protected]>
Comments-Addressed: Arm Jenkins <[email protected]>
1 parent 4a89ff5 commit 96d6de4
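The pattern the commit message describes, reduced to a minimal sketch (the names below are illustrative stand-ins, not the library's actual declarations): execute() consumes array pointers previously stored on the GEMM object, so concurrent callers would race on that shared state, while execute_stateless() receives the arrays with each call and is therefore safe to invoke from many threads at once, but only fixed-format kernels implement it.

    #include <cassert>

    struct GemmArraysSketch { /* input/output pointers and stride info */ };

    struct GemmCommonSketch {
        GemmArraysSketch _gemm_arrays; // shared, mutable state: unsafe under concurrent callers

        // Stateful path: reads whatever set_arrays() stored earlier.
        virtual void execute() { /* ... uses _gemm_arrays ... */ }

        // Stateless path: the arrays arrive with the call, so one kernel object
        // can serve many threads. Only fixed-format kernels override this; the
        // default is an assert in debug builds and a no-op in release builds.
        virtual void execute_stateless(GemmArraysSketch &) { assert(0); }

        virtual ~GemmCommonSketch() = default;
    };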

File tree: 9 files changed (+69, -77 lines)

src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp

Lines changed: 4 additions & 12 deletions

@@ -144,8 +144,8 @@ class GemmHybrid : public GemmCommon<To, To, Tr> {
         return true;
     }
 
-    // Common execution logic.
-    void execute_common(const ndcoord_t &work_range, const ndcoord_t &, int, GemmArrays<To, To, Tr> &g_arrays) {
+    // Execute
+    void execute(const ndcoord_t &work_range, const ndcoord_t &, int) override {
 #ifdef CYCLE_PROFILING
         profiler prof;
 #endif
@@ -156,6 +156,8 @@ class GemmHybrid : public GemmCommon<To, To, Tr> {
         static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
         static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
 
+        auto &g_arrays = this->_gemm_arrays;
+
         /* For now, each work item implies all the K for a given output
          * pixel (so we don't need to synchronize access to the output
          * array). So separate the loop over K blocks here. */
@@ -208,16 +210,6 @@ class GemmHybrid : public GemmCommon<To, To, Tr> {
 
     }
 
-    // Stateless execute
-    void execute_stateless(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid, GemmArrays<To, To, Tr> &g_arrays) override {
-        return execute_common(work_range, thread_locator, threadid, g_arrays);
-    }
-
-    // Execute
-    void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override {
-        execute_common(work_range, thread_locator, threadid, this->_gemm_arrays);
-    }
-
     // Interface implementation - pretransposed
     bool B_is_pretransposed() const override {
         return true;

src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp

Lines changed: 2 additions & 0 deletions

@@ -586,6 +586,8 @@ class GemmHybridIndirect : public GemmCommon<To, Tw, Tr> {
 
     // Stateless execute
    void execute_stateless(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid, GemmArrays<To, Tw, Tr> &g_arrays) override {
+        assert(FixedFormat);
+
         return execute_common(work_range, thread_locator, threadid, g_arrays);
     }

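On the assert(FixedFormat) guard added above: FixedFormat is a compile-time template parameter of the kernel class, and assert() is the standard <cassert> macro, so the check exists only in debug builds. A minimal sketch of the resulting behaviour (the function name is hypothetical):

    #include <cassert>

    template <bool FixedFormat>
    void stateless_entry_sketch() {
        // Debug build: fires for any non-fixed-format instantiation.
        // Release build (NDEBUG defined): assert() expands to nothing, and
        // the call falls through to the shared execution logic.
        assert(FixedFormat);
        // ... execute_common(...) would follow here ...
    }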
src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp

Lines changed: 4 additions & 12 deletions

@@ -166,13 +166,15 @@ class GemmHybridQuantized : public GemmCommon<To, To, Tr> {
         return true;
     }
 
-    // Common execution logic.
-    void execute_common(const ndcoord_t &work_range, const ndcoord_t &, int threadid, GemmArrays<To, To, Tr> &g_arrays) {
+    // Execute
+    void execute(const ndcoord_t &work_range, const ndcoord_t &, int threadid) override {
 #ifdef CYCLE_PROFILING
         profiler prof;
 #endif
         strategy strat(_ci);
 
+        auto &g_arrays = this->_gemm_arrays;
+
         void *working_space = g_arrays._workspace;
         auto working_int = reinterpret_cast<uintptr_t>(working_space);
 
@@ -243,16 +245,6 @@ class GemmHybridQuantized : public GemmCommon<To, To, Tr> {
         }
     }
 
-    // Stateless execute
-    void execute_stateless(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid, GemmArrays<To, To, Tr> &g_arrays) override {
-        return execute_common(work_range, thread_locator, threadid, g_arrays);
-    }
-
-    // Execute
-    void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override {
-        execute_common(work_range, thread_locator, threadid, this->_gemm_arrays);
-    }
-
     // Working space needed for intermediate result buffers.
     size_t get_working_size() const override {
         return (_nthreads * strategy::out_height() * _Nsize * sizeof(Tri));

src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp

Lines changed: 6 additions & 4 deletions

@@ -474,9 +474,9 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
     unsigned int get_col_sum_size() const {
         if (std::is_same<OutputStage, Requantize32>::value) {
             return _Nsize * _nmulti * sizeof(int32_t);
-        } else {
-            return 0;
         }
+
+        return 0;
     }
 
     /* We will need to walk through the blocks of B in a few contexts, so
@@ -576,9 +576,9 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
     size_t get_c_working_size() const {
         if (MergeStep) {
             return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
-        } else {
-            return 0;
         }
+
+        return 0;
     }
 
     // Accumulation buffer size
@@ -1129,6 +1129,8 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
 
     // Stateless execute
     void execute_stateless(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid, GemmArrays<Tlo, Tro, Tr> &g_arrays) override {
+        assert(FixedFormat);
+
         return execute_common(work_range, thread_locator, threadid, g_arrays);
     }

src/core/NEON/kernels/arm_gemm/gemv_batched.hpp

Lines changed: 1 addition & 5 deletions

@@ -64,12 +64,8 @@ class GemvBatched : public GemmCommon<To, To, Tr> {
         _subgemm->set_nthreads(nthreads);
     }
 
-    void execute_stateless(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid, GemmArrays<To, To, Tr> &) override {
-        _subgemm->execute(work_range, thread_locator, threadid);
-    }
-
     void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override {
-        execute_stateless(work_range, thread_locator, threadid, this->_gemm_arrays);
+        _subgemm->execute(work_range, thread_locator, threadid);
     }
 
     size_t get_working_size() const override {

src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp

Lines changed: 4 additions & 12 deletions

@@ -136,13 +136,15 @@ class GemvPretransposed : public GemmCommon<To, To, Tr> {
         return { iceildiv(_args._Nsize, strategy::out_width()) * _args._nmulti };
     }
 
-    // Common execution logic.
-    void execute_common(const ndcoord_t &work_range, const ndcoord_t &, int, GemmArrays<To, To, Tr> &g_arrays) {
+    // Actually execute the GEMV.
+    void execute(const ndcoord_t &work_range, const ndcoord_t &, int) override {
 #ifdef CYCLE_PROFILING
         profiler prof;
 #endif
         strategy strat(_args._ci);
 
+        auto &g_arrays = this->_gemm_arrays;
+
         const auto start = work_range.get_position(0);
         const auto end = work_range.get_position_end(0);
 
@@ -184,16 +186,6 @@ class GemvPretransposed : public GemmCommon<To, To, Tr> {
         }
     }
 
-    // Stateless execute
-    void execute_stateless(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid, GemmArrays<To, To, Tr> &g_arrays) override {
-        return execute_common(work_range, thread_locator, threadid, g_arrays);
-    }
-
-    // Actually execute the GEMV.
-    void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) override {
-        execute_common(work_range, thread_locator, threadid, this->_gemm_arrays);
-    }
-
     /* Pretransposed interface implementation */
     bool B_is_pretransposed() const override {
         return true;

src/cpu/kernels/assembly/gemm_arrays.hpp

Lines changed: 2 additions & 2 deletions

@@ -108,8 +108,8 @@ struct GemmArrays : public IGemmArrays
 
     GemmArrays(const GemmArrays<To, Tw, Tr> &)            = default;
     GemmArrays &operator=(const GemmArrays<To, Tw, Tr> &) = default;
-    GemmArrays(GemmArrays<To, Tw, Tr> &&)                 = delete;
-    GemmArrays &operator=(GemmArrays<To, Tw, Tr> &&)      = delete;
+    GemmArrays(GemmArrays<To, Tw, Tr> &&)                 = default;
+    GemmArrays &operator=(GemmArrays<To, Tw, Tr> &&)      = default;
     ~GemmArrays() override = default;
 
     /* Pass in the pointers to the arrays to be operated on and their

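Defaulting the move operations matters because a deleted move constructor still participates in overload resolution: code that moves a GemmArrays, for example into a container of per-call argument blocks, would fail to compile before this change. A small stand-in example (ArraysSketch is hypothetical; the real GemmArrays holds pointers and strides, so the member-wise default move is cheap and correct):

    #include <utility>
    #include <vector>

    struct ArraysSketch {
        const float *a = nullptr;
        const float *b = nullptr;
        float       *c = nullptr;

        ArraysSketch()                           = default;
        ArraysSketch(ArraysSketch &&)            = default; // with '= delete', main() below fails to compile
        ArraysSketch &operator=(ArraysSketch &&) = default;
    };

    int main() {
        std::vector<ArraysSketch> per_call_args;
        ArraysSketch args;
        per_call_args.push_back(std::move(args)); // needs a usable move (or copy) constructor
    }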
src/cpu/kernels/assembly/gemm_common.hpp

Lines changed: 9 additions & 4 deletions

@@ -30,6 +30,7 @@
 #include "convolution_parameters.hpp"
 #include "gemm_arrays.hpp"
 #include "ndrange.hpp"
+
 #include <cstddef>
 
 namespace arm_gemm
@@ -307,10 +308,14 @@ class GemmCommon : public IGemmCommon
      * @param [in]  threadid a unique threadid
      * @param [out] GemmArrays structure containing the input/output addresses, and stride info
      */
-    virtual void execute_stateless(const ndcoord_t &work_range,
-                                   const ndcoord_t &thread_locator,
-                                   int threadid,
-                                   GemmArrays<To, Tw, Tr> &gemm_array) = 0;
+    virtual void execute_stateless(const ndcoord_t &,
+                                   const ndcoord_t &,
+                                   int,
+                                   GemmArrays<To, Tw, Tr> &)
+    {
+        // This must be overridden in the derived class to be used
+        assert(0);
+    }
 };
 } // namespace arm_gemm

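A side effect of dropping the '= 0' is visible throughout the rest of this diff: non-fixed-format kernels no longer need pass-through overrides, which is exactly what the deletions in gemm_hybrid.hpp, gemm_hybrid_quantized.hpp, gemv_batched.hpp and gemv_pretransposed.hpp exploit. In toy form (names illustrative):

    #include <cassert>

    struct BasePure {
        virtual void f() = 0;           // every concrete derived type must define f()
        virtual ~BasePure() = default;
    };

    struct BaseDefault {
        virtual void f() { assert(0); } // overriding is opt-in; misuse asserts in debug builds
        virtual ~BaseDefault() = default;
    };

    struct NonFixedFormat : BaseDefault {}; // legal: silently inherits the asserting default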
src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp

Lines changed: 37 additions & 26 deletions

@@ -789,40 +789,51 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
         multi_stride_a = 0;
     }
 
-    Tensor in0_tensor;
-    in0_tensor.allocator()->init(*(a->info()));
-    in0_tensor.allocator()->import_memory(const_cast<TypeInput *>(in0_ptr));
+    // Set gemm parameters
+    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr,
+                                 ldd, batch_stride_d, multi_stride_d, bias, 0);
 
-    Tensor in1_tensor;
-    if (b)
+    // Need to pack the input/output pointers separately to use the thread-safe,
+    // stateless-execution interface for fixed-format kernels.
+    if (_gemm_info.fixed_format)
     {
-        in1_tensor.allocator()->init(*(b->info()));
-        in1_tensor.allocator()->import_memory(const_cast<TypeWeight *>(in1_ptr));
-    }
+        Tensor in0_tensor;
+        in0_tensor.allocator()->init(*(a->info()));
+        in0_tensor.allocator()->import_memory(const_cast<TypeInput *>(in0_ptr));
 
-    Tensor bias_tensor;
-    if (c)
-    {
-        bias_tensor.allocator()->init(*(c->info()));
-        bias_tensor.allocator()->import_memory(bias);
-    }
+        Tensor in1_tensor;
+        if (b)
+        {
+            in1_tensor.allocator()->init(*(b->info()));
+            in1_tensor.allocator()->import_memory(const_cast<TypeWeight *>(in1_ptr));
+        }
 
-    Tensor out_tensor;
-    out_tensor.allocator()->init(*(d->info()));
-    out_tensor.allocator()->import_memory(out_ptr);
+        Tensor bias_tensor;
+        if (c)
+        {
+            bias_tensor.allocator()->init(*(c->info()));
+            bias_tensor.allocator()->import_memory(bias);
+        }
 
-    ITensorPack gemm_pack{{ACL_SRC_0, &in0_tensor},
-                          {ACL_SRC_1, &in1_tensor},
-                          {ACL_SRC_2, &bias_tensor},
-                          {ACL_SRC_3, workspace.get()},
-                          {ACL_DST, &out_tensor}};
+        Tensor out_tensor;
+        out_tensor.allocator()->init(*(d->info()));
+        out_tensor.allocator()->import_memory(out_ptr);
 
-    // Set gemm parameters
-    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr,
-                                 ldd, batch_stride_d, multi_stride_d, bias, 0);
+        ITensorPack gemm_pack{{ACL_SRC_0, &in0_tensor},
+                              {ACL_SRC_1, &in1_tensor},
+                              {ACL_SRC_2, &bias_tensor},
+                              {ACL_SRC_3, workspace.get()},
+                              {ACL_DST, &out_tensor}};
+
+        // Schedule thread-safe stateless execution
+        NEScheduler::get().schedule_op(_optimised_kernel.get(), scheduling_hint, _optimised_kernel->window(),
+                                       gemm_pack);
+
+        return;
+    }
 
     // Schedule
-    NEScheduler::get().schedule_op(_optimised_kernel.get(), scheduling_hint, _optimised_kernel->window(), gemm_pack);
+    NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
 }
 
 template <typename TypeInput, typename TypeWeight, typename TypeOutput>

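Condensing the rewritten run() above to its control flow (a sketch: argument handling, tensor setup and helper calls elided):

    // Always store the array pointers on the kernel: the stateful path needs them.
    _gemm_kernel_asm->set_arrays(/* pointers and strides */);

    if (_gemm_info.fixed_format)
    {
        // Wrap the raw pointers in Tensors and an ITensorPack, then hand the
        // pack to the scheduler so each run carries its own state.
        NEScheduler::get().schedule_op(_optimised_kernel.get(), scheduling_hint,
                                       _optimised_kernel->window(), gemm_pack);
        return; // thread-safe, stateless execution
    }

    // Non-fixed-format kernels keep the original stateful scheduling.
    NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);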