Skip to content

Commit 2829ddf

Browse files
committed
feat: unary main primitive + transpose
1 parent 33d8076 commit 2829ddf

File tree

9 files changed

+404
-45
lines changed

9 files changed

+404
-45
lines changed

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,5 @@
131131
"editor.rulers": [140],
132132

133133
"editor.insertSpaces": true,
134-
"editor.tabSize": 2
134+
"editor.tabSize": 2,
135135
}

Testing/Temporary/CTestCostData.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

Testing/Temporary/LastTest.log

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/main/TensorOperation.cpp

Lines changed: 123 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,13 @@ int32_t mini_jit::TensorOperation::findMatch(const std::span<const TensorConfig:
125125
}
126126

127127
bool mini_jit::TensorOperation::isValidPrimConfig(const std::span<const TensorConfig::dim_t> &dim,
128-
const std::span<const TensorConfig::exec_t> &exec,
129-
const std::span<const int64_t> &strides_in0, const std::span<const int64_t> &strides_out)
128+
const std::span<const TensorConfig::exec_t> &exec)
130129
{
131130
int32_t indexM = findMatch(dim, exec, TensorConfig::dim_t::m, TensorConfig::exec_t::prim);
132131
int32_t indexN = findMatch(dim, exec, TensorConfig::dim_t::n, TensorConfig::exec_t::prim);
133132
if (indexM == -1 || indexN == -1)
134133
{
135-
std::cerr << "1: Could not find a matching index: indexM:" << indexM << ", indexN:" << indexN << std::endl;
136-
return false;
137-
}
138-
139-
if (!(isExpectedStride(1, indexM, strides_in0) && isExpectedStride(1, indexM, strides_out)))
140-
{
134+
std::cerr << "isValidPrimConfig 1: Could not find a matching index: indexM:" << indexM << ", indexN:" << indexN << std::endl;
141135
return false;
142136
}
143137

@@ -146,16 +140,48 @@ bool mini_jit::TensorOperation::isValidPrimConfig(const std::span<const TensorCo
146140
indexN = findMatch(dim, exec, TensorConfig::dim_t::n, TensorConfig::exec_t::prim, indexN + 1);
147141
if (indexM != -1 || indexN != -1)
148142
{
149-
std::cerr << "2: Could not find a matching index: indexM:" << indexM << ", indexN" << indexN << std::endl;
143+
std::cerr << "isValidPrimConfig 2: Could not find a matching index: indexM:" << indexM << ", indexN" << indexN << std::endl;
150144
return false;
151145
}
152146

153147
return true;
154148
}
155149

150+
bool mini_jit::TensorOperation::isValidPrimStrides(const std::span<const TensorConfig::dim_t> &dim,
151+
const std::span<const TensorConfig::exec_t> &exec,
152+
const std::span<const int64_t> &strides_in0, const std::span<const int64_t> &strides_out,
153+
const TensorConfig::prim_t main_prim)
154+
{
155+
int32_t indexM = findMatch(dim, exec, TensorConfig::dim_t::m, TensorConfig::exec_t::prim);
156+
int32_t indexN = findMatch(dim, exec, TensorConfig::dim_t::n, TensorConfig::exec_t::prim);
157+
if (indexM == -1 || indexN == -1)
158+
{
159+
std::cerr << "isValidStride: Could not find a matching index: indexM:" << indexM << ", indexN:" << indexN << std::endl;
160+
return false;
161+
}
162+
163+
// stride of m = 1 and stride of n = 1 i.e. no transpose
164+
if (isExpectedStride(1, indexM, strides_in0) && isExpectedStride(1, indexM, strides_out))
165+
{
166+
return true;
167+
}
168+
169+
// Check transpose in unary op
170+
if (isUnary(main_prim) && isExpectedStride(1, indexM, strides_in0) && isExpectedStride(1, indexN, strides_out))
171+
{
172+
isTranspose = true;
173+
return true;
174+
}
175+
176+
std::cerr << "isValidStride: Could not find a valid stride: in0: m-stride: " << strides_in0[indexM]
177+
<< ", n-stride: " << strides_in0[indexN] << "; out: m-stride: " << strides_out[indexM] << ", n-stride: " << strides_out[indexN]
178+
<< std::endl;
179+
return false;
180+
}
181+
156182
bool mini_jit::TensorOperation::isValidKDim(const std::span<const TensorConfig::dim_t> &dim,
157183
const std::span<const TensorConfig::exec_t> &exec, const std::span<const int64_t> &strides_in1,
158-
TensorConfig::prim_t prim)
184+
const TensorConfig::prim_t prim)
159185
{
160186
if (isBrgemm(prim))
161187
{
@@ -249,23 +275,77 @@ bool mini_jit::TensorOperation::isValidStride(const std::span<const TensorConfig
249275
switch (strideType)
250276
{
251277
case stride_t::in0:
252-
if (*iDim == TensorConfig::dim_t::n && *iStride != 0)
278+
switch (*iDim)
253279
{
254-
return false;
280+
case TensorConfig::dim_t::c:
281+
case TensorConfig::dim_t::m:
282+
case TensorConfig::dim_t::k:
283+
if (*iStride == 0)
284+
{
285+
return false;
286+
}
287+
break;
288+
289+
case TensorConfig::dim_t::n:
290+
if (*iStride != 0)
291+
{
292+
return false;
293+
}
294+
break;
295+
296+
default:
297+
release_assert(false, "Found unhandled dimension type.");
298+
break;
255299
}
256300
break;
257301

258302
case stride_t::in1:
259-
if (*iDim == TensorConfig::dim_t::m && *iStride != 0)
303+
switch (*iDim)
260304
{
261-
return false;
305+
case TensorConfig::dim_t::c:
306+
case TensorConfig::dim_t::n:
307+
case TensorConfig::dim_t::k:
308+
if (*iStride == 0)
309+
{
310+
return false;
311+
}
312+
break;
313+
314+
case TensorConfig::dim_t::m:
315+
if (*iStride != 0)
316+
{
317+
return false;
318+
}
319+
break;
320+
321+
default:
322+
release_assert(false, "Found unhandled dimension type.");
323+
break;
262324
}
263325
break;
264326

265327
case stride_t::out:
266-
if (*iDim == TensorConfig::dim_t::k && *iStride != 0)
328+
switch (*iDim)
267329
{
268-
return false;
330+
case TensorConfig::dim_t::c:
331+
case TensorConfig::dim_t::n:
332+
case TensorConfig::dim_t::m:
333+
if (*iStride == 0)
334+
{
335+
return false;
336+
}
337+
break;
338+
339+
case TensorConfig::dim_t::k:
340+
if (*iStride != 0)
341+
{
342+
return false;
343+
}
344+
break;
345+
346+
default:
347+
release_assert(false, "Found unhandled dimension type.");
348+
break;
269349
}
270350
break;
271351

@@ -279,7 +359,7 @@ bool mini_jit::TensorOperation::isValidStride(const std::span<const TensorConfig
279359
}
280360

281361
mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary &unary, TensorConfig::prim_t prim,
282-
const std::span<const int64_t> &dim_sizes)
362+
const std::span<const int64_t> &dim_sizes, bool isTranspose)
283363
{
284364
release_assert(indexPrimM != -1, "Expected a match for the m primitive dimension");
285365
release_assert(indexPrimN != -1, "Expected a match for the n primitive dimension");
@@ -303,7 +383,8 @@ mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary &unary,
303383
release_assert(false, "Found a invalid type for the unary first touch.");
304384
break;
305385
}
306-
return unary.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], 0, Unary::dtype_t::fp32, type);
386+
387+
return unary.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], isTranspose, Unary::dtype_t::fp32, type);
307388
}
308389

309390
mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(const TensorConfig &config)
@@ -322,7 +403,10 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
322403
std::span<const TensorConfig::dim_t> dim_types, std::span<const TensorConfig::exec_t> exec_types, std::span<const int64_t> dim_sizes,
323404
std::span<const int64_t> strides_in0, std::span<const int64_t> strides_in1, std::span<const int64_t> strides_out)
324405
{
406+
// Reset to defaults
325407
hasSetupError = false;
408+
isParallel = false;
409+
isTranspose = false;
326410
indexPrimBatch = -1;
327411
indexPrimK = -1;
328412
indexPrimM = -1;
@@ -367,6 +451,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
367451
if (kDimExecType != -1)
368452
{
369453
hasSetupError = true;
454+
std::cerr << "Error: Found k dimension tagged as shared, but can not execute k dimension as shared." << std::endl;
370455
return error_t::err_k_dimension_must_not_be_shared;
371456
}
372457
}
@@ -387,14 +472,21 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
387472
return error_t::err_invalid_execution_order;
388473
}
389474

390-
if (!isValidPrimConfig(dim_types, exec_types, strides_in0, strides_out))
475+
if (!isValidPrimConfig(dim_types, exec_types))
391476
{
392477
hasSetupError = true;
393478
std::cerr << "Error: Invalid primitive configuration detected. Expected one primitive for m and one primitive for n to exist"
394479
<< std::endl;
395480
return error_t::err_invalid_primitive_configuration;
396481
}
397482

483+
if (!isValidPrimStrides(dim_types, exec_types, strides_in0, strides_out, prim_main))
484+
{
485+
hasSetupError = true;
486+
std::cerr << "Error: Invalid strides for the primitive m dimension (or n dimension if transpose)." << std::endl;
487+
return error_t::err_invalid_strides;
488+
}
489+
398490
if (!isValidKDim(dim_types, exec_types, strides_in1, prim_main))
399491
{
400492
hasSetupError = true;
@@ -412,6 +504,13 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
412504
std::cerr << "Error: Invalid stride configuration detected for unary. Expected k-dimension to have a stride of zero." << std::endl;
413505
return error_t::err_invalid_strides;
414506
}
507+
508+
if (prim_last_touch != TensorConfig::prim_t::none || prim_last_touch != TensorConfig::prim_t::none)
509+
{
510+
hasSetupError = true;
511+
std::cerr << "Error: A main 'Unary' primitive can not have first touch and last touch primitives." << std::endl;
512+
return error_t::err_invalid_main_configuration;
513+
}
415514
}
416515
else if (isBrgemm(prim_main))
417516
{
@@ -448,7 +547,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
448547
first_touch.emplace<Unary>();
449548
TensorOperation::prim_first = prim_first_touch;
450549

451-
Unary::error_t error = generateUnary(std::get<Unary>(first_touch), prim_first_touch, dim_sizes);
550+
Unary::error_t error = generateUnary(std::get<Unary>(first_touch), prim_first_touch, dim_sizes, false);
452551

453552
if (error != Unary::error_t::success)
454553
{
@@ -517,7 +616,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
517616
main_kernel.emplace<Unary>();
518617
TensorOperation::prim_main = prim_main;
519618

520-
Unary::error_t error = generateUnary(std::get<Unary>(main_kernel), prim_main, dim_sizes);
619+
Unary::error_t error = generateUnary(std::get<Unary>(main_kernel), prim_main, dim_sizes, isTranspose);
521620

522621
if (error != Unary::error_t::success)
523622
{
@@ -541,7 +640,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
541640
last_touch.emplace<Unary>();
542641
TensorOperation::prim_last = prim_last_touch;
543642

544-
Unary::error_t error = generateUnary(std::get<Unary>(last_touch), prim_last_touch, dim_sizes);
643+
Unary::error_t error = generateUnary(std::get<Unary>(last_touch), prim_last_touch, dim_sizes, false);
545644

546645
if (error != Unary::error_t::success)
547646
{
@@ -663,7 +762,8 @@ void mini_jit::TensorOperation::execute_dimension(int64_t index_dim, char const
663762
if (std::holds_alternative<Unary>(main_kernel))
664763
{
665764
Unary::kernel_t kernel = std::get<Unary>(main_kernel).get_kernel();
666-
kernel(ptr_in0, ptr_out, strides_in0[indexPrimN], strides_out[indexPrimN]);
765+
int32_t indexLeadingDimension = isTranspose ? indexPrimM : indexPrimN;
766+
kernel(ptr_in0, ptr_out, strides_in0[indexPrimN], strides_out[indexLeadingDimension]);
667767
}
668768
else if (std::holds_alternative<Brgemm>(main_kernel))
669769
{

src/main/TensorOperation.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ namespace mini_jit
7171

7272
bool isParallel = false; // default is sequential execution
7373

74+
bool isTranspose = false; // default is no transpose
75+
7476
bool hasSetupError = false;
7577

7678
/**
@@ -81,21 +83,33 @@ namespace mini_jit
8183
* @return true The configuration is a valid primitive setup.
8284
* @return false The configuration is NOT a valid primitive setup.
8385
*/
84-
bool isValidPrimConfig(const std::span<const TensorConfig::dim_t> &dim, const std::span<const TensorConfig::exec_t> &exec,
85-
const std::span<const int64_t> &strides_in0, const std::span<const int64_t> &strides_out);
86+
bool isValidPrimConfig(const std::span<const TensorConfig::dim_t> &dim, const std::span<const TensorConfig::exec_t> &exec);
87+
88+
/**
89+
* @brief Validates that the strides of the m primitives and n primitives dimension are unit strides.
90+
*
91+
* @param dim The dimension types to search through.
92+
* @param exec The execution types to search through.
93+
* @param prim_main The main primitive of the tensor operation.
94+
* @return true The configuration has valid strides.
95+
* @return false The configuration does not have valid unit strides.
96+
*/
97+
bool isValidPrimStrides(const std::span<const TensorConfig::dim_t> &dim, const std::span<const TensorConfig::exec_t> &exec,
98+
const std::span<const int64_t> &strides_in0, const std::span<const int64_t> &strides_out,
99+
const TensorConfig::prim_t main_prim);
86100

87101
/**
88102
* @brief Checks if the K dimension is valid for the given primitive.
89103
*
90104
* @param dim The dimension types to search through.
91105
* @param exec The execution types to search through.
92-
* @param
106+
* @param strides_in1 The strides of the second input.
93107
* @param prim The primitive i.e. Gemm or Brgemm to be executed.
94108
* @return true The configuration is a valid setup.
95109
* @return false The configuration is NOT a valid setup.
96110
*/
97111
bool isValidKDim(const std::span<const TensorConfig::dim_t> &dim, const std::span<const TensorConfig::exec_t> &exec,
98-
const std::span<const int64_t> &strides_in1, TensorConfig::prim_t prim);
112+
const std::span<const int64_t> &strides_in1, const TensorConfig::prim_t prim);
99113

100114
/**
101115
* @brief Checks if the configuration is sorted such that the primitives are last.
@@ -112,9 +126,10 @@ namespace mini_jit
112126
* @param unary The unary used for generation.
113127
* @param prim The primitive that is generated.
114128
* @param dim_sizes The sizes of each dimension.
129+
* @param isTranspose Indicates if the unary is executes a tranpose operation.
115130
* @return Unary::error_t
116131
*/
117-
Unary::error_t generateUnary(Unary &unary, TensorConfig::prim_t prim, const std::span<const int64_t> &dim_sizes);
132+
Unary::error_t generateUnary(Unary &unary, TensorConfig::prim_t prim, const std::span<const int64_t> &dim_sizes, bool isTranspose);
118133

119134
public:
120135
/**
@@ -221,11 +236,10 @@ namespace mini_jit
221236
void execute_dimension(int64_t index_dimension, char const *ptr_in0, char const *ptr_in1, char *ptr_out, bool first_access,
222237
bool last_access);
223238

224-
225239
/**
226240
* @brief Get the current configuration object.
227-
*
228-
* @return TensorConfig used by the Tensor operation.
241+
*
242+
* @return TensorConfig used by the Tensor operation.
229243
*/
230244
TensorConfig get_config();
231245
};

0 commit comments

Comments
 (0)