33#include < ranges>
44#include < tuple>
55
6- mini_jit::TensorOperation::~TensorOperation ()
7- {
8- cleanup ();
9- }
10-
11- void mini_jit::TensorOperation::cleanup ()
12- {
13- if (isUnary (prim_first))
14- {
15- delete first_touch.unary ;
16- prim_first = prim_t ::none;
17- }
18-
19- if (isUnary (prim_main))
20- {
21- delete main.unary ;
22- prim_main = prim_t ::none;
23- }
24- else if (isBrgemm (prim_main))
25- {
26- delete main.brgemm ;
27- prim_main = prim_t ::none;
28- }
29-
30- if (isUnary (prim_last))
31- {
32- delete last_touch.unary ;
33- prim_last = prim_t ::none;
34- }
35-
36- release_assert (prim_first == prim_t ::none, " Expected prim_first to be none after cleanup." );
37- release_assert (prim_main == prim_t ::none, " Expected prim_main to be none after cleanup." );
38- release_assert (prim_last == prim_t ::none, " Expected prim_last to be none after cleanup." );
39- }
40-
416bool mini_jit::TensorOperation::isUnary (prim_t prim)
427{
438 return prim == prim_t ::copy || prim == prim_t ::relu || prim == prim_t ::relu;
@@ -146,7 +111,7 @@ bool mini_jit::TensorOperation::isExpectedStride(int64_t expected, int index, co
146111 return strides[index] == expected;
147112}
148113
149- mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary (Unary * unary, prim_t prim, const std::span<const dim_t > &dim_types,
114+ mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary (Unary & unary, prim_t prim, const std::span<const dim_t > &dim_types,
150115 const std::span<const exec_t > &exec_types,
151116 const std::span<const int64_t > &dim_sizes)
152117{
@@ -172,7 +137,7 @@ mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary *unary,
172137 release_assert (false , " Found a invalid type for the unary first touch." );
173138 break ;
174139 }
175- return unary-> generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], 0 , Unary::dtype_t ::fp32, type);
140+ return unary. generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], 0 , Unary::dtype_t ::fp32, type);
176141}
177142
178143mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup (dtype_t dtype, prim_t prim_first_touch, prim_t prim_main,
@@ -182,9 +147,6 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
182147 std::span<const int64_t > strides_in1,
183148 std::span<const int64_t > strides_out)
184149{
185- // clear all old used resources
186- cleanup ();
187-
188150 TensorOperation::prim_first = prim_t ::none; // Not yet generated, correctness of cleanup
189151 TensorOperation::prim_main = prim_t ::none; // Not yet generated, correctness of cleanup
190152 TensorOperation::prim_last = prim_t ::none; // Not yet generated, correctness of cleanup
@@ -236,10 +198,10 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
236198 {
237199 if (prim_first_touch == prim_t ::zero || prim_first_touch == prim_t ::copy || prim_first_touch == prim_t ::relu)
238200 {
239- first_touch.unary = new Unary ();
201+ first_touch.emplace < Unary> ();
240202 TensorOperation::prim_first = prim_first_touch;
241203
242- Unary::error_t error = generateUnary (first_touch. unary , prim_first_touch, dim_types, exec_types, dim_sizes);
204+ Unary::error_t error = generateUnary (std::get<Unary>( first_touch) , prim_first_touch, dim_types, exec_types, dim_sizes);
243205
244206 if (error != Unary::error_t ::success)
245207 {
@@ -256,21 +218,22 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
256218 {
257219 if (isBrgemm (prim_main))
258220 {
259- main.brgemm = new Brgemm ();
221+ main.emplace < Brgemm> ();
260222 TensorOperation::prim_main = prim_main;
261223
262224 if (prim_main == prim_t ::brgemm)
263225 {
264226 indexPrimBatch = findMatch (dim_types, exec_types, dim_t ::k, exec_t ::prim);
265227 indexPrimK = findMatch (dim_types, exec_types, dim_t ::k, exec_t ::prim, indexPrimBatch);
266228
267- main. brgemm -> generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], dim_sizes[indexPrimBatch], 0 , 0 , 0 ,
268- Brgemm::dtype_t ::fp32);
229+ std::get<Brgemm>( main). generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], dim_sizes[indexPrimBatch], 0 ,
230+ 0 , 0 , Brgemm::dtype_t ::fp32);
269231 }
270232 else if (prim_main == prim_t ::gemm)
271233 {
272234 indexPrimK = findMatch (dim_types, exec_types, dim_t ::k, exec_t ::prim);
273- main.brgemm ->generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], 1 , 0 , 0 , 0 , Brgemm::dtype_t ::fp32);
235+ std::get<Brgemm>(main).generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], 1 , 0 , 0 , 0 ,
236+ Brgemm::dtype_t ::fp32);
274237 }
275238 else
276239 {
@@ -279,10 +242,9 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
279242 }
280243 else if (isUnary (prim_main))
281244 {
282- main.unary = new Unary ();
283- TensorOperation::prim_main = prim_main;
245+ main.emplace <Unary>();
284246
285- Unary::error_t error = generateUnary (main. unary , prim_main, dim_types, exec_types, dim_sizes);
247+ Unary::error_t error = generateUnary (std::get<Unary>( main) , prim_main, dim_types, exec_types, dim_sizes);
286248
287249 if (error != Unary::error_t ::success)
288250 {
@@ -299,10 +261,10 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
299261 {
300262 if (isUnary (prim_last_touch))
301263 {
302- last_touch.unary = new Unary ();
264+ last_touch.emplace < Unary> ();
303265 TensorOperation::prim_last = prim_last_touch;
304266
305- Unary::error_t error = generateUnary (last_touch. unary , prim_last_touch, dim_types, exec_types, dim_sizes);
267+ Unary::error_t error = generateUnary (std::get<Unary>( last_touch) , prim_last_touch, dim_types, exec_types, dim_sizes);
306268
307269 if (error != Unary::error_t ::success)
308270 {
0 commit comments