Skip to content

Commit 8bc9600

Browse files
committed
Allow AArch64 JIT reorder and relax stride checks for f16 paths. The AArch64 jit_uni_reorder now treats pure f16->f16 reorders as valid (previously only f32<->f16 conversions passed), preventing an unnecessary fallback to the reference implementation. For f16 cases, the small-stride requirement is relaxed so that blocked and large-stride layouts can stay on the JIT path instead of degrading to the reference reorder. This should reduce reference reorder usage and keep f16 workloads on optimized kernels on AArch64.
1 parent 33bfbec commit 8bc9600

25 files changed

+5420
-4090
lines changed

src/cpu/aarch64/acl_reorder.cpp

Lines changed: 0 additions & 52 deletions
This file was deleted.

src/cpu/aarch64/acl_reorder.hpp

Lines changed: 9 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2023-2025 Arm Ltd. and affiliates
2+
* Copyright 2025 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,239 +13,22 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*******************************************************************************/
16-
#ifndef CPU_ACL_REORDER_HPP
17-
#define CPU_ACL_REORDER_HPP
1816

19-
#include "arm_compute/core/Types.h"
20-
#include "common/utils.hpp"
21-
#include "cpu/acl/acl_utils.hpp"
22-
#include "cpu/aarch64/cpu_isa_traits.hpp"
23-
#include "cpu/reorder/cpu_reorder_pd.hpp"
17+
#ifndef CPU_AARCH64_ACL_REORDER_HPP
18+
#define CPU_AARCH64_ACL_REORDER_HPP
2419

20+
// Keep include path compatibility with code that expects this header.
21+
#include "cpu/aarch64/reorder/acl_reorder.hpp"
22+
23+
// Provide the expected cpu::acl namespace alias used by common headers.
2524
namespace dnnl {
2625
namespace impl {
2726
namespace cpu {
2827
namespace acl {
29-
30-
struct acl_reorder_obj_t {
31-
arm_compute::NEReorderLayer reorder;
32-
arm_compute::Tensor src_tensor;
33-
arm_compute::Tensor dst_tensor;
34-
arm_compute::WeightFormat src_wf;
35-
arm_compute::WeightFormat dst_wf;
36-
};
37-
38-
struct acl_reorder_conf_t {
39-
arm_compute::TensorInfo src_info;
40-
arm_compute::TensorInfo dst_info;
41-
arm_compute::WeightFormat src_wf;
42-
arm_compute::WeightFormat dst_wf;
43-
};
44-
45-
struct acl_reorder_resource_t : public resource_t {
46-
acl_reorder_resource_t()
47-
: acl_obj_(utils::make_unique<acl_reorder_obj_t>()) {}
48-
49-
status_t configure(const acl_reorder_conf_t &app) {
50-
if (!acl_obj_) return status::out_of_memory;
51-
52-
// Init Compute Library tensors based on info from descriptor
53-
acl_obj_->src_tensor.allocator()->init(app.src_info);
54-
acl_obj_->dst_tensor.allocator()->init(app.dst_info);
55-
56-
// clang-format off
57-
acl_obj_->reorder.configure(
58-
&acl_obj_->src_tensor,
59-
&acl_obj_->dst_tensor,
60-
app.src_wf,
61-
app.dst_wf
62-
);
63-
// clang-format on
64-
65-
return status::success;
66-
}
67-
68-
acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; }
69-
DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t);
70-
71-
private:
72-
std::unique_ptr<acl_reorder_obj_t> acl_obj_;
73-
}; // acl_reorder_resource_t
74-
75-
struct acl_reorder_fwd_t : public primitive_t {
76-
using primitive_t::primitive_t;
77-
struct pd_t : public cpu_reorder_pd_t {
78-
79-
using cpu_reorder_pd_t::cpu_reorder_pd_t;
80-
81-
DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t);
82-
83-
static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
84-
const primitive_attr_t *attr, engine_t *src_engine,
85-
const memory_desc_t *src_md, engine_t *dst_engine,
86-
const memory_desc_t *dst_md) {
87-
88-
using namespace acl_utils;
89-
90-
// ACL reorder support f32->f32 and f32->bf16
91-
bool ok = src_md->data_type == data_type::f32
92-
&& utils::one_of(
93-
dst_md->data_type, data_type::f32, data_type::bf16)
94-
&& attr->has_default_values();
95-
96-
if (!ok) return status::unimplemented;
97-
98-
if (!attr->scales_.has_default_values(DNNL_ARG_DST)) {
99-
int mask = attr->scales_.get_mask(DNNL_ARG_DST);
100-
const memory_desc_wrapper input_d(src_md);
101-
if (input_d.has_runtime_dims_or_strides() && mask > 0)
102-
return status::unimplemented;
103-
}
104-
105-
// Create and check primitive descriptor
106-
auto _pd = make_unique_pd<pd_t>(attr, src_engine->kind(), src_md,
107-
dst_engine->kind(), dst_md);
108-
if (_pd == nullptr) return status::out_of_memory;
109-
if (_pd->init(engine, src_engine, dst_engine) != status::success) {
110-
return status::unimplemented;
111-
}
112-
113-
// In case we have two or four dimensions, we can't have one of the
114-
// two first dimensions as 1. This is valid for f32->f32 and f32->bf16.
115-
if (dst_md->dims[0] == 1 || dst_md->dims[1] == 1) {
116-
return status::unimplemented;
117-
}
118-
119-
auto src_tag = memory_desc_matches_one_of_tag(
120-
*src_md, format_tag::ab, format_tag::ba, format_tag::cdba);
121-
ACL_CHECK_SUPPORT(format_tag::undef == src_tag,
122-
"Only ab, ba or cdba source formats supported");
123-
124-
auto dst_tag = memory_desc_matches_one_of_tag(*dst_md,
125-
format_tag::BA8b4a, format_tag::BA4b4a, format_tag::Ab4a,
126-
format_tag::Ab8a, format_tag::Acdb8a, format_tag::Acdb4a);
127-
ACL_CHECK_SUPPORT(format_tag::undef == dst_tag,
128-
"Only Ab4a/Ab8a, BA8b4a/BA4b4a and Acdb8a/Acdb4a "
129-
"destination formats supported");
130-
131-
if (dst_tag == format_tag::BA4b4a || dst_tag == format_tag::Acdb4a
132-
|| dst_tag == format_tag::Ab4a) {
133-
_pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4;
134-
} else if (aarch64::mayiuse(aarch64::sve_256)
135-
&& (dst_tag == format_tag::BA8b4a
136-
|| dst_tag == format_tag::Acdb8a
137-
|| dst_tag == format_tag::Ab8a)) {
138-
_pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8;
139-
} else {
140-
return status::unimplemented;
141-
}
142-
143-
arm_compute::TensorShape acl_tensor_shape_in;
144-
arm_compute::TensorShape acl_tensor_shape_out;
145-
146-
// Switch for 2 or 4 dim tensors
147-
switch (src_md->ndims) {
148-
case 2: {
149-
if (src_tag == format_tag::ab
150-
&& dst_md->data_type == data_type::bf16
151-
&& utils::one_of(dst_tag, format_tag::BA8b4a,
152-
format_tag::BA4b4a)) { // bf16
153-
acl_tensor_shape_in = arm_compute::TensorShape(
154-
src_md->dims[0], src_md->dims[1]);
155-
acl_tensor_shape_out = arm_compute::TensorShape(
156-
dst_md->padded_dims[0], dst_md->padded_dims[1]);
157-
} else if (src_tag == format_tag::ba
158-
&& dst_md->data_type == data_type::f32
159-
&& !utils::one_of(dst_tag, format_tag::BA8b4a,
160-
format_tag::BA4b4a)) { // f32
161-
acl_tensor_shape_in = arm_compute::TensorShape(
162-
src_md->dims[1], src_md->dims[0]);
163-
acl_tensor_shape_out = arm_compute::TensorShape(
164-
dst_md->padded_dims[1], dst_md->padded_dims[0]);
165-
} else {
166-
return status::unimplemented;
167-
}
168-
} break;
169-
case 4: {
170-
// Currently only supporting AxBx1x1 cases
171-
if (dst_md->dims[2] != 1 || dst_md->dims[3] != 1) {
172-
return status::unimplemented;
173-
}
174-
175-
acl_tensor_shape_in = arm_compute::TensorShape(
176-
src_md->dims[3], src_md->dims[2], src_md->dims[1],
177-
src_md->dims[0]);
178-
acl_tensor_shape_out = arm_compute::TensorShape(
179-
dst_md->padded_dims[3], dst_md->padded_dims[2],
180-
dst_md->padded_dims[1], dst_md->padded_dims[0]);
181-
break;
182-
}
183-
default: return status::unimplemented;
184-
}
185-
186-
// Choose the data layout
187-
const auto acl_layout = arm_compute::DataLayout::NCHW;
188-
189-
// Set Source WeightFormat
190-
_pd->app_.src_wf = arm_compute::WeightFormat::OHWI;
191-
192-
// Create ACL tensor infos
193-
const arm_compute::DataType src_acl_data_t
194-
= acl_utils::get_acl_data_t(src_md->data_type);
195-
_pd->app_.src_info = arm_compute::TensorInfo(
196-
acl_tensor_shape_in, 1, src_acl_data_t, acl_layout);
197-
198-
const arm_compute::DataType dst_acl_data_t
199-
= acl_utils::get_acl_data_t(dst_md->data_type);
200-
_pd->app_.dst_info = arm_compute::TensorInfo(
201-
acl_tensor_shape_out, 1, dst_acl_data_t, acl_layout);
202-
203-
ACL_CHECK_VALID(arm_compute::NEReorderLayer::validate(
204-
&_pd->app_.src_info, &_pd->app_.dst_info, _pd->app_.src_wf,
205-
_pd->app_.dst_wf));
206-
207-
// Init scratch memory, not used so 0 in this implementation
208-
_pd->init_scratchpad_md();
209-
210-
return safe_ptr_assign(*reorder_pd, _pd.release());
211-
} // create
212-
213-
friend dnnl::impl::impl_list_item_t;
214-
acl_reorder_conf_t app_;
215-
216-
}; // pd_t
217-
218-
acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {}
219-
220-
status_t create_resource(
221-
engine_t *engine, resource_mapper_t &mapper) const override {
222-
if (mapper.has_resource(this)) return status::success;
223-
224-
auto r = utils::make_unique<acl_reorder_resource_t>();
225-
if (!r) return status::out_of_memory;
226-
227-
// Configure the resource based on information from primitive descriptor
228-
CHECK(r->configure(pd()->app_));
229-
230-
mapper.add(this, std::move(r));
231-
return status::success;
232-
}
233-
234-
status_t execute(const exec_ctx_t &ctx) const override {
235-
return execute_forward(ctx);
236-
}
237-
238-
private:
239-
// To guard the const execute_forward, the mutex must be 'mutable'
240-
mutable std::mutex mtx;
241-
status_t execute_forward(const exec_ctx_t &ctx) const;
242-
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
243-
244-
}; // acl_reorder_fwd_t
245-
28+
using aarch64::acl_reorder_fwd_t;
24629
} // namespace acl
24730
} // namespace cpu
24831
} // namespace impl
24932
} // namespace dnnl
25033

251-
#endif // CPU_ACL_REORDER_HPP
34+
#endif

0 commit comments

Comments
 (0)