Skip to content

Commit ec810ce

Browse files
add support for scalable vectors
1 parent c6c7a7b commit ec810ce

32 files changed

+432
-20
lines changed

ir/type.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ir/state.h"
77
#include "smt/solver.h"
88
#include "util/compiler.h"
9+
#include "util/config.h"
910
#include <array>
1011
#include <cassert>
1112
#include <numeric>
@@ -1089,12 +1090,14 @@ void ArrayType::print(ostream &os) const {
10891090
}
10901091

10911092

1092-
VectorType::VectorType(string &&name, unsigned elements, Type &elementTy)
1093-
: AggregateType(std::move(name), false) {
1094-
assert(elements != 0);
1095-
this->elements = elements;
1093+
VectorType::VectorType(string &&name, unsigned elems, Type &elTy, bool scal)
1094+
: AggregateType(std::move(name), false), scalable(scal), min_elements(elems) {
1095+
assert(elems != 0);
1096+
if (scalable)
1097+
elems *= util::config::vscale_value;
1098+
this->elements = elems;
10961099
defined = true;
1097-
children.resize(elements, &elementTy);
1100+
children.resize(elements, &elTy);
10981101
is_padding.resize(elements, false);
10991102
}
11001103

@@ -1178,8 +1181,13 @@ expr VectorType::enforceVectorType(
11781181
}
11791182

11801183
void VectorType::print(ostream &os) const {
1181-
if (elements)
1182-
os << '<' << elements << " x " << *children[0] << '>';
1184+
if (!elements)
1185+
return;
1186+
os << '<';
1187+
if (scalable) {
1188+
os << "vscale" << ":" << util::config::vscale_value << " x ";
1189+
}
1190+
os << min_elements << " x " << *children[0] << '>';
11831191
}
11841192

11851193

ir/type.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "ir/attrs.h"
77
#include "smt/expr.h"
8+
#include "util/config.h"
89

910
#include <functional>
1011
#include <memory>
@@ -335,9 +336,13 @@ class ArrayType final : public AggregateType {
335336

336337

337338
class VectorType final : public AggregateType {
339+
private:
340+
bool scalable = false;
341+
unsigned min_elements;
342+
338343
public:
339344
VectorType(std::string &&name) : AggregateType(std::move(name)) {}
340-
VectorType(std::string &&name, unsigned elements, Type &elementTy);
345+
VectorType(std::string &&name, unsigned elems, Type &elemTy, bool scalable);
341346

342347
IR::StateValue extract(const IR::StateValue &vector,
343348
const smt::expr &index) const;

llvm_util/cmd_args_def.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ config::debug = opt_debug;
2424
config::quiet = opt_quiet;
2525
config::max_offset_bits = opt_max_offset_in_bits;
2626
config::max_sizet_bits = opt_max_sizet_in_bits;
27+
config::vscale_value = opt_vscale;
2728

2829
if ((config::disallow_ub_exploitation = opt_disallow_ub_exploitation)) {
2930
config::disable_undef_input = true;

llvm_util/cmd_args_list.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,4 +185,9 @@ llvm::cl::opt<bool> opt_disallow_ub_exploitation(
185185
llvm::cl::desc("Disallow UB exploitation by optimizations (default=allow)"),
186186
llvm::cl::init(false), llvm::cl::cat(alive_cmdargs));
187187

188+
llvm::cl::opt<unsigned> opt_vscale(LLVM_ARGS_PREFIX "vscale",
189+
llvm::cl::desc("Set vscale value for scalable vectors (default=2)"),
190+
llvm::cl::init(2), llvm::cl::value_desc("value"),
191+
llvm::cl::cat(alive_cmdargs));
192+
188193
}

llvm_util/llvm2alive.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,13 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
11791179
PARSE_BINOP();
11801180
return make_unique<VaCopy>(*a, *b);
11811181
}
1182+
case llvm::Intrinsic::vscale: {
1183+
auto ty = llvm_type2alive(i.getType());
1184+
if (!ty)
1185+
return error(i);
1186+
auto val = make_intconst(config::vscale_value, ty->bits());
1187+
return make_unique<UnaryOp>(*ty, value_name(i), *val, UnaryOp::Copy);
1188+
}
11821189

11831190
// do nothing intrinsics
11841191
case llvm::Intrinsic::dbg_declare:
@@ -1287,8 +1294,17 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
12871294
RetTy visitShuffleVectorInst(llvm::ShuffleVectorInst &i) {
12881295
PARSE_BINOP();
12891296
vector<unsigned> mask;
1290-
for (auto m : i.getShuffleMask())
1291-
mask.push_back(m);
1297+
1298+
unsigned replicate = 1;
1299+
if (i.getType()->isScalableTy()) {
1300+
replicate = config::vscale_value;
1301+
}
1302+
1303+
auto &&sm = i.getShuffleMask();
1304+
for (unsigned j = 0; j < replicate; j++) {
1305+
mask.insert(mask.end(), sm.begin(), sm.end());
1306+
}
1307+
12921308
return
12931309
make_unique<ShuffleVector>(*ty, value_name(i), *a, *b, std::move(mask));
12941310
}

llvm_util/utils.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ Type* llvm_type2alive(const llvm::Type *ty) {
199199
return cache.get();
200200
}
201201
// fixed-size and scalable vectors (scalable ones are sized using config::vscale_value)
202-
case llvm::Type::FixedVectorTyID: {
202+
case llvm::Type::FixedVectorTyID:
203+
case llvm::Type::ScalableVectorTyID: {
203204
auto &cache = type_cache[ty];
204205
if (!cache) {
205206
auto vty = cast<llvm::VectorType>(ty);
@@ -208,7 +209,7 @@ Type* llvm_type2alive(const llvm::Type *ty) {
208209
if (!ety || elems > 1024)
209210
return nullptr;
210211
cache = make_unique<VectorType>("ty_" + to_string(type_id_counter++),
211-
elems, *ety);
212+
elems, *ety, vty->isScalableTy());
212213
}
213214
return cache.get();
214215
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
; @tgt writes %i into lane %i of a <vscale x 2 x i32> for every %i in
; [0, 2*vscale) and returns the last lane, i.e. 2*vscale - 1.
; With --vscale=1 that is 1, the constant @src returns.
; NOTE(review): -tgt-unroll=2 presumably unrolls the loop in @tgt enough to
; cover its 2 iterations -- confirm against the tool's option documentation.
1+
define i32 @src() {
2+
ret i32 1
3+
}
4+
5+
define i32 @tgt() {
6+
entry:
7+
%vs = call i32 @llvm.vscale.i32()
8+
%len = mul i32 %vs, 2
9+
%init_vec = insertelement <vscale x 2 x i32> poison, i32 0, i32 0
10+
br label %for.cond
11+
12+
for.cond:
13+
%i = phi i32 [ 0, %entry ], [ %inc, %for.body ]
14+
%vec = phi <vscale x 2 x i32> [ %init_vec, %entry ], [ %new_vec, %for.body ]
15+
%cmp = icmp slt i32 %i, %len
16+
br i1 %cmp, label %for.body, label %exit
17+
18+
for.body:
19+
%new_vec = insertelement <vscale x 2 x i32> %vec, i32 %i, i32 %i
20+
%inc = add i32 %i, 1
21+
br label %for.cond
22+
23+
exit:
24+
%last_idx = sub i32 %len, 1
25+
%result = extractelement <vscale x 2 x i32> %vec, i32 %last_idx
26+
ret i32 %result
27+
}
28+
29+
; TEST-ARGS: --vscale=1 -tgt-unroll=2
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
; @tgt writes %i into lane %i of a <vscale x 2 x i32> for every %i in
; [0, 2*vscale) and returns the last lane, i.e. 2*vscale - 1.
; With --vscale=2 that is 3, the constant @src returns.
; NOTE(review): -tgt-unroll=4 presumably unrolls the loop in @tgt enough to
; cover its 4 iterations -- confirm against the tool's option documentation.
1+
define i32 @src() {
2+
ret i32 3
3+
}
4+
define i32 @tgt() {
5+
entry:
6+
%vs = call i32 @llvm.vscale.i32()
7+
%len = mul i32 %vs, 2
8+
%init_vec = insertelement <vscale x 2 x i32> poison, i32 0, i32 0
9+
br label %for.cond
10+
11+
for.cond:
12+
%i = phi i32 [ 0, %entry ], [ %inc, %for.body ]
13+
%vec = phi <vscale x 2 x i32> [ %init_vec, %entry ], [ %new_vec, %for.body ]
14+
%cmp = icmp slt i32 %i, %len
15+
br i1 %cmp, label %for.body, label %exit
16+
17+
for.body:
18+
%new_vec = insertelement <vscale x 2 x i32> %vec, i32 %i, i32 %i
19+
%inc = add i32 %i, 1
20+
br label %for.cond
21+
22+
exit:
23+
%last_idx = sub i32 %len, 1
24+
%result = extractelement <vscale x 2 x i32> %vec, i32 %last_idx
25+
ret i32 %result
26+
}
27+
28+
; TEST-ARGS: --vscale=2 -tgt-unroll=4
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
; @tgt writes %i into lane %i of a <vscale x 2 x i32> for every %i in
; [0, 2*vscale) and returns the last lane, i.e. 2*vscale - 1.
; With --vscale=3 that is 5, the constant @src returns.
; NOTE(review): -tgt-unroll=6 presumably unrolls the loop in @tgt enough to
; cover its 6 iterations -- confirm against the tool's option documentation.
1+
define i32 @src() {
2+
ret i32 5
3+
}
4+
define i32 @tgt() {
5+
entry:
6+
%vs = call i32 @llvm.vscale.i32()
7+
%len = mul i32 %vs, 2
8+
%init_vec = insertelement <vscale x 2 x i32> poison, i32 0, i32 0
9+
br label %for.cond
10+
11+
for.cond:
12+
%i = phi i32 [ 0, %entry ], [ %inc, %for.body ]
13+
%vec = phi <vscale x 2 x i32> [ %init_vec, %entry ], [ %new_vec, %for.body ]
14+
%cmp = icmp slt i32 %i, %len
15+
br i1 %cmp, label %for.body, label %exit
16+
17+
for.body:
18+
%new_vec = insertelement <vscale x 2 x i32> %vec, i32 %i, i32 %i
19+
%inc = add i32 %i, 1
20+
br label %for.cond
21+
22+
exit:
23+
%last_idx = sub i32 %len, 1
24+
%result = extractelement <vscale x 2 x i32> %vec, i32 %last_idx
25+
ret i32 %result
26+
}
27+
28+
; TEST-ARGS: --vscale=3 -tgt-unroll=6
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
; @src inserts 0 into lane 0 of its argument and then shuffles the result
; with a poison mask; @tgt returns poison directly. The pair is run with
; --vscale=1.
; NOTE(review): this relies on a poison shuffle mask making every result lane
; poison, so the inserted value is irrelevant -- confirm against the LLVM
; LangRef entry for shufflevector.
1+
define <vscale x 2 x i32> @src(<vscale x 2 x i32> %vec) {
2+
%insert = insertelement <vscale x 2 x i32> %vec, i32 0, i32 0
3+
%shuf = shufflevector <vscale x 2 x i32> %insert, <vscale x 2 x i32> poison, <vscale x 2 x i32> poison
4+
ret <vscale x 2 x i32> %shuf
5+
}
6+
7+
define <vscale x 2 x i32> @tgt(<vscale x 2 x i32> %vec) {
8+
ret <vscale x 2 x i32> poison
9+
}
10+
11+
; TEST-ARGS: --vscale=1

0 commit comments

Comments
 (0)