Skip to content

Commit 8702ef4

Browse files
committed
refactor: use variant instead of union
1 parent 6138177 commit 8702ef4

File tree

2 files changed

+18
-68
lines changed

2 files changed

+18
-68
lines changed

src/main/TensorOperation.cpp

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,6 @@
33
#include <ranges>
44
#include <tuple>
55

6-
mini_jit::TensorOperation::~TensorOperation()
7-
{
8-
cleanup();
9-
}
10-
11-
void mini_jit::TensorOperation::cleanup()
12-
{
13-
if (isUnary(prim_first))
14-
{
15-
delete first_touch.unary;
16-
prim_first = prim_t::none;
17-
}
18-
19-
if (isUnary(prim_main))
20-
{
21-
delete main.unary;
22-
prim_main = prim_t::none;
23-
}
24-
else if (isBrgemm(prim_main))
25-
{
26-
delete main.brgemm;
27-
prim_main = prim_t::none;
28-
}
29-
30-
if (isUnary(prim_last))
31-
{
32-
delete last_touch.unary;
33-
prim_last = prim_t::none;
34-
}
35-
36-
release_assert(prim_first == prim_t::none, "Expected prim_first to be none after cleanup.");
37-
release_assert(prim_main == prim_t::none, "Expected prim_main to be none after cleanup.");
38-
release_assert(prim_last == prim_t::none, "Expected prim_last to be none after cleanup.");
39-
}
40-
416
bool mini_jit::TensorOperation::isUnary(prim_t prim)
427
{
438
return prim == prim_t::copy || prim == prim_t::relu || prim == prim_t::relu;
@@ -146,7 +111,7 @@ bool mini_jit::TensorOperation::isExpectedStride(int64_t expected, int index, co
146111
return strides[index] == expected;
147112
}
148113

149-
mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary *unary, prim_t prim, const std::span<const dim_t> &dim_types,
114+
mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary &unary, prim_t prim, const std::span<const dim_t> &dim_types,
150115
const std::span<const exec_t> &exec_types,
151116
const std::span<const int64_t> &dim_sizes)
152117
{
@@ -172,7 +137,7 @@ mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary *unary,
172137
release_assert(false, "Found a invalid type for the unary first touch.");
173138
break;
174139
}
175-
return unary->generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], 0, Unary::dtype_t::fp32, type);
140+
return unary.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], 0, Unary::dtype_t::fp32, type);
176141
}
177142

178143
mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtype, prim_t prim_first_touch, prim_t prim_main,
@@ -182,9 +147,6 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
182147
std::span<const int64_t> strides_in1,
183148
std::span<const int64_t> strides_out)
184149
{
185-
// clear all old used resources
186-
cleanup();
187-
188150
TensorOperation::prim_first = prim_t::none; // Not yet generated, correctness of cleanup
189151
TensorOperation::prim_main = prim_t::none; // Not yet generated, correctness of cleanup
190152
TensorOperation::prim_last = prim_t::none; // Not yet generated, correctness of cleanup
@@ -236,10 +198,10 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
236198
{
237199
if (prim_first_touch == prim_t::zero || prim_first_touch == prim_t::copy || prim_first_touch == prim_t::relu)
238200
{
239-
first_touch.unary = new Unary();
201+
first_touch.emplace<Unary>();
240202
TensorOperation::prim_first = prim_first_touch;
241203

242-
Unary::error_t error = generateUnary(first_touch.unary, prim_first_touch, dim_types, exec_types, dim_sizes);
204+
Unary::error_t error = generateUnary(std::get<Unary>(first_touch), prim_first_touch, dim_types, exec_types, dim_sizes);
243205

244206
if (error != Unary::error_t::success)
245207
{
@@ -256,21 +218,22 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
256218
{
257219
if (isBrgemm(prim_main))
258220
{
259-
main.brgemm = new Brgemm();
221+
main.emplace<Brgemm>();
260222
TensorOperation::prim_main = prim_main;
261223

262224
if (prim_main == prim_t::brgemm)
263225
{
264226
indexPrimBatch = findMatch(dim_types, exec_types, dim_t::k, exec_t::prim);
265227
indexPrimK = findMatch(dim_types, exec_types, dim_t::k, exec_t::prim, indexPrimBatch);
266228

267-
main.brgemm->generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], dim_sizes[indexPrimBatch], 0, 0, 0,
268-
Brgemm::dtype_t::fp32);
229+
std::get<Brgemm>(main).generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], dim_sizes[indexPrimBatch], 0,
230+
0, 0, Brgemm::dtype_t::fp32);
269231
}
270232
else if (prim_main == prim_t::gemm)
271233
{
272234
indexPrimK = findMatch(dim_types, exec_types, dim_t::k, exec_t::prim);
273-
main.brgemm->generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], 1, 0, 0, 0, Brgemm::dtype_t::fp32);
235+
std::get<Brgemm>(main).generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], 1, 0, 0, 0,
236+
Brgemm::dtype_t::fp32);
274237
}
275238
else
276239
{
@@ -279,10 +242,9 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
279242
}
280243
else if (isUnary(prim_main))
281244
{
282-
main.unary = new Unary();
283-
TensorOperation::prim_main = prim_main;
245+
main.emplace<Unary>();
284246

285-
Unary::error_t error = generateUnary(main.unary, prim_main, dim_types, exec_types, dim_sizes);
247+
Unary::error_t error = generateUnary(std::get<Unary>(main), prim_main, dim_types, exec_types, dim_sizes);
286248

287249
if (error != Unary::error_t::success)
288250
{
@@ -299,10 +261,10 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
299261
{
300262
if (isUnary(prim_last_touch))
301263
{
302-
last_touch.unary = new Unary();
264+
last_touch.emplace<Unary>();
303265
TensorOperation::prim_last = prim_last_touch;
304266

305-
Unary::error_t error = generateUnary(last_touch.unary, prim_last_touch, dim_types, exec_types, dim_sizes);
267+
Unary::error_t error = generateUnary(std::get<Unary>(last_touch), prim_last_touch, dim_types, exec_types, dim_sizes);
306268

307269
if (error != Unary::error_t::success)
308270
{

src/main/TensorOperation.h

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "Unary.h"
66
#include <cstdint>
77
#include <span>
8+
#include <variant>
89
#include <vector>
910

1011
namespace mini_jit
@@ -88,22 +89,9 @@ class mini_jit::TensorOperation
8889
int32_t indexPrimK = -1;
8990
int32_t indexPrimBatch = -1;
9091

91-
union Primitive
92-
{
93-
mini_jit::Unary *unary;
94-
mini_jit::Brgemm *brgemm;
95-
};
96-
97-
Primitive first_touch;
98-
Primitive main;
99-
Primitive last_touch;
100-
101-
~TensorOperation();
102-
103-
/**
104-
* @brief Cleans up the current set primitives.
105-
*/
106-
void cleanup();
92+
std::variant<mini_jit::Brgemm, mini_jit::Unary> first_touch;
93+
std::variant<mini_jit::Brgemm, mini_jit::Unary> main;
94+
std::variant<mini_jit::Brgemm, mini_jit::Unary> last_touch;
10795

10896
/**
10997
* @brief Indicates if a primitive fits the Unary generator.
@@ -162,7 +150,7 @@ class mini_jit::TensorOperation
162150

163151
static bool isExpectedStride(int64_t expected, int index, const std::span<const int64_t> strides);
164152

165-
Unary::error_t generateUnary(Unary *unary, prim_t prim, const std::span<const dim_t> &dim_types,
153+
Unary::error_t generateUnary(Unary &unary, prim_t prim, const std::span<const dim_t> &dim_types,
166154
const std::span<const exec_t> &exec_types, const std::span<const int64_t> &dim_sizes);
167155

168156
public:

0 commit comments

Comments
 (0)