96 changes: 10 additions & 86 deletions src/CodeGen_ARM.cpp
@@ -1475,17 +1475,13 @@ void CodeGen_ARM::visit(const Store *op) {
is_float16_and_has_feature(elt) ||
elt == Int(8) || elt == Int(16) || elt == Int(32) || elt == Int(64) ||
elt == UInt(8) || elt == UInt(16) || elt == UInt(32) || elt == UInt(64)) {
// TODO(zvookin): Handle vector_bits_*.
const int target_vector_bits = native_vector_bits();
if (vec_bits % 128 == 0) {
type_ok_for_vst = true;
int target_vector_bits = native_vector_bits();
if (target_vector_bits == 0) {
target_vector_bits = 128;
}
intrin_type = intrin_type.with_lanes(target_vector_bits / t.bits());
} else if (vec_bits % 64 == 0) {
type_ok_for_vst = true;
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? 128 : 64;
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? target_vector_bits : 64;
intrin_type = intrin_type.with_lanes(intrin_bits / t.bits());
}
}
@@ -1494,7 +1490,9 @@ void CodeGen_ARM::visit(const Store *op) {
if (ramp && is_const_one(ramp->stride) &&
shuffle && shuffle->is_interleave() &&
type_ok_for_vst &&
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4 &&
// TODO: we could handle predicated_store once shuffle_vector becomes robust for scalable vectors
!is_predicated_store) {

const int num_vecs = shuffle->vectors.size();
vector<Value *> args(num_vecs);
@@ -1513,7 +1511,6 @@ void CodeGen_ARM::visit(const Store *op) {
for (int i = 0; i < num_vecs; ++i) {
args[i] = codegen(shuffle->vectors[i]);
}
Value *store_pred_val = codegen(op->predicate);

bool is_sve = target.has_feature(Target::SVE2);

@@ -1559,8 +1556,8 @@ void CodeGen_ARM::visit(const Store *op) {
llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
internal_assert(fn);

// SVE2 supports predication for smaller than whole vector size.
internal_assert(target.has_feature(Target::SVE2) || (t.lanes() >= intrin_type.lanes()));
// Scalable vectors support predication for sizes smaller than the whole vector.
internal_assert(target_vscale() > 0 || (t.lanes() >= intrin_type.lanes()));

for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
Expr slice_base = simplify(ramp->base + i * num_vecs);
@@ -1581,15 +1578,10 @@ void CodeGen_ARM::visit(const Store *op) {
slice_args.push_back(ConstantInt::get(i32_t, alignment));
} else {
if (is_sve) {
// Set the predicate argument
// Set the predicate argument to mask active lanes
auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes());
Value *vpred_val;
if (is_predicated_store) {
vpred_val = slice_vector(store_pred_val, i, intrin_type.lanes());
} else {
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
vpred_val = codegen(vpred);
}
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
Value *vpred_val = codegen(vpred);
slice_args.push_back(vpred_val);
}
// Set the pointer argument
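
As context for the change above (not part of the diff): the store is emitted in intrin_type.lanes()-wide slices, and each slice now always gets a mask of 1s for the lanes still in range followed by 0s for the tail, instead of slicing a predicated-store mask. Below is a minimal standalone sketch of that active-lane arithmetic; the 12-lane total and 8-lane slice sizes are illustrative assumptions.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Lanes [slice_start, slice_start + active) get 1s and the rest of the
    // slice gets 0s, in the same 1s-then-0s layout as make_vector_predicate_1s_0s.
    std::vector<int> predicate_for_slice(int total_lanes, int slice_lanes, int slice_start) {
        int active = std::min(total_lanes - slice_start, slice_lanes);
        std::vector<int> pred(slice_lanes, 0);
        std::fill(pred.begin(), pred.begin() + active, 1);
        return pred;
    }

    int main() {
        const int total_lanes = 12, slice_lanes = 8;  // illustrative sizes only
        for (int i = 0; i < total_lanes; i += slice_lanes) {
            std::vector<int> pred = predicate_for_slice(total_lanes, slice_lanes, i);
            std::printf("slice at lane %2d:", i);
            for (int p : pred) std::printf(" %d", p);
            std::printf("\n");
        }
        return 0;
    }

Run as written, the second slice's mask prints as 1 1 1 1 0 0 0 0, matching the mask the code above builds for a partial final slice.
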
@@ -1810,74 +1802,6 @@ void CodeGen_ARM::visit(const Load *op) {
CodeGen_Posix::visit(op);
return;
}
} else if (stride && (2 <= stride->value && stride->value <= 4)) {
// Structured load ST2/ST3/ST4 of SVE

Expr base = ramp->base;
ModulusRemainder align = op->alignment;

int aligned_stride = gcd(stride->value, align.modulus);
int offset = 0;
if (aligned_stride == stride->value) {
offset = mod_imp((int)align.remainder, aligned_stride);
} else {
const Add *add = base.as<Add>();
if (const IntImm *add_c = add ? add->b.as<IntImm>() : base.as<IntImm>()) {
offset = mod_imp(add_c->value, stride->value);
}
}

if (offset) {
base = simplify(base - offset);
}

Value *load_pred_val = codegen(op->predicate);

// We need to slice the result in to native vector lanes to use sve intrin.
// LLVM will optimize redundant ld instructions afterwards
const int slice_lanes = target.natural_vector_size(op->type);
vector<Value *> results;
for (int i = 0; i < op->type.lanes(); i += slice_lanes) {
int load_base_i = i * stride->value;
Expr slice_base = simplify(base + load_base_i);
Expr slice_index = Ramp::make(slice_base, stride, slice_lanes);
std::ostringstream instr;
instr << "llvm.aarch64.sve.ld"
<< stride->value
<< ".sret.nxv"
<< slice_lanes
<< (op->type.is_float() ? 'f' : 'i')
<< op->type.bits();
llvm::Type *elt = llvm_type_of(op->type.element_of());
llvm::Type *slice_type = get_vector_type(elt, slice_lanes);
StructType *sret_type = StructType::get(module->getContext(), std::vector(stride->value, slice_type));
std::vector<llvm::Type *> arg_types{get_vector_type(i1_t, slice_lanes), ptr_t};
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);

// Set the predicate argument
int active_lanes = std::min(op->type.lanes() - i, slice_lanes);

Expr vpred = make_vector_predicate_1s_0s(active_lanes, slice_lanes - active_lanes);
Value *vpred_val = codegen(vpred);
vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, get_vector_type(vpred_val->getType()->getScalarType(), slice_lanes));
if (is_predicated_load) {
Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, slice_lanes);
vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val);
}

Value *elt_ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
CallInst *load_i = builder->CreateCall(fn, {vpred_val, elt_ptr});
add_tbaa_metadata(load_i, op->name, slice_index);
// extract one element out of returned struct
Value *extracted = builder->CreateExtractValue(load_i, offset);
results.push_back(extracted);
}

// Retrieve original lanes
value = concat_vectors(results);
value = slice_vector(value, 0, op->type.lanes());
return;
} else if (op->index.type().is_vector()) {
// General Gather Load

82 changes: 45 additions & 37 deletions test/correctness/simd_op_check_sve2.cpp
@@ -63,7 +63,9 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
void add_tests() override {
check_arm_integer();
check_arm_float();
check_arm_load_store();
if (Halide::Internal::get_llvm_version() >= 220) {
check_arm_load_store();
}
Member:

Skipping this test is a red flag. We can't degrade codegen quality on supported LLVM versions, especially not across all released LLVM versions. That test (and perhaps this patch) should be adjusted to confirm superior codegen on LLVM >= 22 (or more broadly), but it must also confirm the existing behavior on LLVM < 22.

Contributor Author:

Thank you for the feedback. I've updated it so that the test scope from before this PR is kept on older LLVM versions. The test cases enabled by this PR are run only with LLVM >= 22.

Contributor Author:

I would suggest removing the scalar load/store test, for the reasons given in this comment from the patch:

        // We skip the scalar load/store test due to the following challenges.
        // The rule by which LLVM selects instructions does not seem simple.
        // For example, ld1, ldr, or ldp may be chosen as the instruction, with a z or q register as the operand,
        // depending on the data type, vscale, what is performed before/after the load, and the LLVM version.
        // The other issue is that load/store instructions appear in places other than the one we want to check,
        // which makes the test prone to false-positive detection, since we only search strings line by line.

In addition, this is not a test of SIMD. The previous one looked like it was passing, but only because of false positives.
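
To illustrate that last point, here is a minimal standalone sketch (not part of the PR; the assembly listing and the helper symbol name are made up for the example) of how per-line substring matching over generated assembly can report a hit even though no LD2 instruction was emitted:

    #include <iostream>
    #include <sstream>
    #include <string>

    int main() {
        // Made-up assembly listing; the "ld2" in the symbol name is purely illustrative.
        const std::string asm_listing =
            "  ldr q0, [x0]\n"
            "  st2 { v0.4s, v1.4s }, [x1]\n"
            "  bl _halide_buffer_ld2_helper\n";

        std::istringstream in(asm_listing);
        std::string line;
        int hits = 0;
        while (std::getline(in, line)) {
            if (line.find("ld2") != std::string::npos) {
                ++hits;  // matches the symbol name, not a real LD2 instruction
            }
        }
        std::cout << "naive 'ld2' hits: " << hits << "\n";  // prints 1: a false positive
        return 0;
    }

The only "ld2" hit comes from a symbol name on an unrelated line, which is exactly the kind of false positive described above.
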

check_arm_pairwise();
}

@@ -677,6 +679,9 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
vector<tuple<Type, CastFuncTy>> test_params = {
{Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};

const int base_vec_bits = has_sve() ? target.vector_bits : 128;
const int vscale = base_vec_bits / 128;

for (const auto &[elt, in_im] : test_params) {
const int bits = elt.bits();
if ((elt == Float(16) && !is_float16_supported()) ||
@@ -687,8 +692,9 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
// LD/ST - Load/Store
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue; // bail out scalar op
const int vector_lanes = base_vec_bits / bits;
const int instr_lanes = min(total_lanes, vector_lanes);
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

// In case of arm32, instruction selection looks inconsistent due to optimization by LLVM
AddTestFunctor add(*this, bits, total_lanes, target.bits == 64);
@@ -712,44 +718,59 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}
}

// LD2/ST2 - Load/Store two-element structures
int base_vec_bits = has_sve() ? target.vector_bits : 128;
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
// LDn - Structured load of strided elements
for (int stride = 2; stride <= 4; ++stride) {

for (int factor = 1; factor <= 4; factor *= 2) {
const int vector_lanes = base_vec_bits * factor / bits;

// In StageStridedLoads.cpp, (stride < r->lanes) is the condition for staging to happen
// See https://github.com/halide/Halide/issues/8819
if (vector_lanes <= stride) continue;

AddTestFunctor add_ldn(*this, bits, vector_lanes);

Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1);

const string ldn_str = "ld" + to_string(stride);
if (has_sve()) {
add_ldn({get_sve_ls_instr(ldn_str, bits)}, vector_lanes, load_n);
} else {
add_ldn(sel_op("v" + ldn_str + ".", ldn_str), load_n);
}
}
}

// ST2 - Store two-element structures
for (int width = base_vec_bits * 2; width <= base_vec_bits * 8; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = min(vector_lanes, base_vec_bits / bits);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1);
Expr store_2 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(inssue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
#endif
add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
} else {
add_ldn(sel_op("vld2.", "ld2"), load_2);
add_stn(sel_op("vst2.", "st2"), store_2);
}
}

// Also check when the two expressions interleaved have a common
// subexpression, which results in a vector var being lifted out.
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
for (int width = base_vec_bits * 2; width <= base_vec_bits * 4; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

@@ -768,14 +789,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}
}

// LD3/ST3 - Store three-element structures
for (int width = 192; width <= 192 * 4; width *= 2) {
// ST3 - Store three-element structures
for (int width = base_vec_bits * 3; width <= base_vec_bits * 3 * 2; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 3;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
Expand All @@ -785,29 +805,22 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
x % 3 == 1, tmp1(x / 3 + 16),
tmp1(x / 3 + 32));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2);
Expr store_3 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
#endif
} else {
add_ldn(sel_op("vld3.", "ld3"), load_3);
add_stn(sel_op("vst3.", "st3"), store_3);
}
}

// LD4/ST4 - Store four-element structures
for (int width = 256; width <= 256 * 4; width *= 2) {
// ST4 - Store four-element structures
for (int width = base_vec_bits * 4; width <= base_vec_bits * 4 * 2; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 4;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
Expand All @@ -818,17 +831,11 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
x % 4 == 2, tmp1(x / 4 + 32),
tmp1(x / 4 + 48));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3);
Expr store_4 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
#endif
} else {
add_ldn(sel_op("vld4.", "ld4"), load_4);
add_stn(sel_op("vst4.", "st4"), store_4);
}
}
Expand All @@ -838,7 +845,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (total_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add(*this, bits, total_lanes);
Expr index = clamp(cast<int>(in_im(x)), 0, W - 1);
@@ -1295,6 +1302,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {

auto ext = Internal::get_output_info(target);
std::map<OutputFileType, std::string> outputs = {
{OutputFileType::stmt, file_name + ext.at(OutputFileType::stmt).extension},
{OutputFileType::llvm_assembly, file_name + ext.at(OutputFileType::llvm_assembly).extension},
{OutputFileType::c_header, file_name + ext.at(OutputFileType::c_header).extension},
{OutputFileType::object, file_name + ext.at(OutputFileType::object).extension},