Skip to content

Commit c187f97

Browse files
committed
chore: added main unary test + better error messages
1 parent 4565d96 commit c187f97

File tree

2 files changed

+122
-94
lines changed

2 files changed

+122
-94
lines changed

src/main/TensorOperation.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
303303
if (dim_sizes.size() != dim_types.size() || dim_sizes.empty() || dim_types.empty())
304304
{
305305
hasSetupError = true;
306-
std::cerr << "Error: Dimension sizes and types must match and cannot be empty." << std::endl;
306+
std::cerr << "Error: Dimension sizes and types must match and cannot be empty, but got dim_sizes: " << dim_sizes.size() << ", dim_types"
307+
<< dim_types.size() << std::endl;
307308
return error_t::err_wrong_dimension;
308309
}
309310

@@ -314,7 +315,9 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
314315
(isUnary(prim_last_touch) || prim_last_touch == prim_t::none) && strides_in1.empty()))))
315316
{
316317
hasSetupError = true;
317-
std::cerr << "Error: Strides must match the number of dimensions." << std::endl;
318+
std::cerr << "Error: Strides must match the number of dimensions, but got dim_sizes: " << dim_sizes.size()
319+
<< ", strides_in0: " << strides_in0.size() << ", strides_in1: " << strides_in1.size()
320+
<< ", strides_out:" << strides_out.size() << std::endl;
318321
return error_t::err_wrong_dimension; // Strides must match the number of dimensions
319322
}
320323

@@ -331,26 +334,31 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
331334
if (dtype != dtype_t::fp32)
332335
{
333336
hasSetupError = true;
337+
std::cerr << "Error: data type must be fp32, but got " << static_cast<uint32_t>(dtype) << std::endl;
334338
return error_t::err_wrong_dtype;
335339
}
336340

337341
if (!isSortedConfiguration(exec_types))
338342
{
339343
hasSetupError = true;
344+
std::cerr << "Error: Expected the execution types to be sorted in the order: (shared*, sequential*, primitive*)" << std::endl;
340345
return error_t::err_invalid_execution_order;
341346
}
342347

343348
if (!isValidPrimConfig(dim_types, exec_types, strides_in0, strides_out))
344349
{
345350
hasSetupError = true;
346-
std::cerr << "1: Invalid primitive configuration detected" << std::endl;
351+
std::cerr << "Error: Invalid primitive configuration detected. Expected one primitive for m and one primitive for n to exist"
352+
<< std::endl;
347353
return error_t::err_invalid_primitive_configuration;
348354
}
349355

350356
if (!isValidKDim(dim_types, exec_types, strides_in1, prim_main))
351357
{
352358
hasSetupError = true;
353-
std::cerr << "2: Invalid primitive configuration detected" << std::endl;
359+
std::cerr << "Error: Invalid primitive configuration detected. Expected to find zero primitive k dimension for unary, one primitive k "
360+
"dimension for gemm, two primitive k dimension."
361+
<< std::endl;
354362
return error_t::err_invalid_primitive_configuration;
355363
}
356364

@@ -359,7 +367,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
359367
if (!isValidStride(dim_types, strides_in0, stride_t::out) || !isValidStride(dim_types, strides_out, stride_t::out))
360368
{
361369
hasSetupError = true;
362-
std::cerr << "3: Invalid stride configuration detected for unary" << std::endl;
370+
std::cerr << "Error: Invalid stride configuration detected for unary. Expected k-dimension to have a stride of zero." << std::endl;
363371
return error_t::err_invalid_strides;
364372
}
365373
}
@@ -369,7 +377,9 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
369377
!isValidStride(dim_types, strides_out, stride_t::out))
370378
{
371379
hasSetupError = true;
372-
std::cerr << "3: Invalid stride configuration detected for brgemm" << std::endl;
380+
std::cerr << "Error: Invalid stride configuration detected for brgemm. Expected for in0 to have n-dimension stride of zero, in1 to "
381+
"have m-dimension stride of zero and out to have k-dimension stride of zero."
382+
<< std::endl;
373383
return error_t::err_invalid_strides;
374384
}
375385
}
@@ -401,12 +411,14 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
401411
if (error != Unary::error_t::success)
402412
{
403413
hasSetupError = true;
414+
std::cerr << "Error: while generating the first touch unary: " << static_cast<uint32_t>(error) << std::endl;
404415
return error_t::err_invalid_first_touch_configuration;
405416
}
406417
}
407418
else
408419
{
409420
hasSetupError = true;
421+
std::cerr << "Error: Invalid type for the first touch primitive, only support zero, copy, relu." << std::endl;
410422
return error_t::err_wrong_first_touch_primitive;
411423
}
412424
}
@@ -426,18 +438,32 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
426438
release_assert(indexPrimBatch != -1, "Expected a valid index for the Batch dimension but found none.");
427439
release_assert(indexPrimK != -1, "Expected a valid index for the Batch dimension but found none.");
428440

429-
std::get<Brgemm>(main_kernel)
430-
.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], dim_sizes[indexPrimBatch], 0, 0, 0,
431-
Brgemm::dtype_t::fp32);
441+
Brgemm::error_t error = std::get<Brgemm>(main_kernel)
442+
.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], dim_sizes[indexPrimBatch],
443+
0, 0, 0, Brgemm::dtype_t::fp32);
444+
if (error != Brgemm::error_t::success)
445+
{
446+
hasSetupError = true;
447+
std::cerr << "Error: while generating the main brgemm: " << static_cast<uint32_t>(error) << std::endl;
448+
return error_t::err_invalid_main_configuration;
449+
}
432450
}
433451
else if (prim_main == prim_t::gemm)
434452
{
435453
indexPrimK = findMatch(dim_types, exec_types, dim_t::k, exec_t::prim);
436454

437455
release_assert(indexPrimK != -1, "Expected a valid index for the K dimension but found none.");
438456

439-
std::get<Brgemm>(main_kernel)
440-
.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], 1, 0, 0, 0, Brgemm::dtype_t::fp32);
457+
Brgemm::error_t error =
458+
std::get<Brgemm>(main_kernel)
459+
.generate(dim_sizes[indexPrimM], dim_sizes[indexPrimN], dim_sizes[indexPrimK], 1, 0, 0, 0, Brgemm::dtype_t::fp32);
460+
461+
if (error != Brgemm::error_t::success)
462+
{
463+
hasSetupError = true;
464+
std::cerr << "Error: while generating the main gemm: " << static_cast<uint32_t>(error) << std::endl;
465+
return error_t::err_invalid_main_configuration;
466+
}
441467
}
442468
else
443469
{
@@ -454,12 +480,14 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
454480
if (error != Unary::error_t::success)
455481
{
456482
hasSetupError = true;
483+
std::cerr << "Error: while generating the main unary: " << static_cast<uint32_t>(error) << std::endl;
457484
return error_t::err_invalid_main_configuration;
458485
}
459486
}
460487
else
461488
{
462489
hasSetupError = true;
490+
std::cerr << "Error: Invalid type for the main primitive, only support zero, copy, relu, gemm, brgemm." << std::endl;
463491
return error_t::err_wrong_main_primitive;
464492
}
465493
}
@@ -476,12 +504,14 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtyp
476504
if (error != Unary::error_t::success)
477505
{
478506
hasSetupError = true;
479-
return error_t::err_invalid_main_configuration;
507+
std::cerr << "Error: while generating the last touch unary: " << static_cast<uint32_t>(error) << std::endl;
508+
return error_t::err_invalid_last_touch_configuration;
480509
}
481510
}
482511
else
483512
{
484513
hasSetupError = true;
514+
std::cerr << "Error: Invalid type for the last touch primitive, only support zero, copy, relu." << std::endl;
485515
return error_t::err_wrong_last_touch_primitive;
486516
}
487517
}

src/test/TensorOperation.test.cpp

Lines changed: 80 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -624,97 +624,95 @@ TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & ma
624624
* =================================================================================================
625625
* =================================================================================================
626626
*/
627-
// TEST_CASE("Test tensor operation with outer loop with main kernel: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
628-
// {
629-
// using namespace mini_jit;
630627

631-
// auto type = GENERATE(TensorOperation::prim_t::zero, TensorOperation::prim_t::relu, TensorOperation::prim_t::copy);
628+
TEST_CASE("Test tensor operation with outer loop with main kernel: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
629+
{
630+
using namespace mini_jit;
632631

633-
// CAPTURE(type);
632+
auto type = GENERATE(TensorOperation::prim_t::zero, TensorOperation::prim_t::relu, TensorOperation::prim_t::copy);
634633

635-
// constexpr TensorOperation::dim_t dim_types[]{TensorOperation::dim_t::n, TensorOperation::dim_t::k, TensorOperation::dim_t::c,
636-
// TensorOperation::dim_t::m, TensorOperation::dim_t::k, TensorOperation::dim_t::m,
637-
// TensorOperation::dim_t::m, TensorOperation::dim_t::n};
638-
// constexpr TensorOperation::exec_t exec_types[]{TensorOperation::exec_t::seq, TensorOperation::exec_t::seq,
639-
// TensorOperation::exec_t::seq,
640-
// TensorOperation::exec_t::seq, TensorOperation::exec_t::seq,
641-
// TensorOperation::exec_t::seq, TensorOperation::exec_t::prim,
642-
// TensorOperation::exec_t::prim};
643-
// constexpr int64_t dim_sizes[]{2, 3, 5, 8, 13, 21, 64, 64};
644-
// constexpr int64_t strides_in0[]{64 * 64 * 21 * 13 * 8 * 5 * 3 * 2,
645-
// 64 * 64 * 21 * 13 * 8 * 5 * 1,
646-
// 64 * 64 * 21 * 13 * 8 * 5,
647-
// 64 * 64 * 21 * 13 * 8,
648-
// 64 * 64 * 21 * 13,
649-
// 64 * 64 * 21,
650-
// 64 * 64,
651-
// 1,
652-
// 64};
653-
// constexpr int64_t strides_in1[]{0, 0, 0, 0, 0, 0, 0, 0};
654-
// constexpr int64_t strides_out[]{64 * 64 * 21 * 13 * 8 * 5 * 3 * 2,
655-
// 64 * 64 * 21 * 13 * 8 * 5 * 1,
656-
// 64 * 64 * 21 * 13 * 8 * 5,
657-
// 64 * 64 * 21 * 13 * 8,
658-
// 64 * 64 * 21 * 1,
659-
// 64 * 64 * 21,
660-
// 64 * 64,
661-
// 1,
662-
// 64};
663-
664-
// GenerationTest test(64, 64, 64, 1, 64 * 64 * 21 * 13 * 8 * 5 * 3 * 2, 0, 64 * 64 * 21 * 13 * 8 * 5 * 3 * 2);
634+
CAPTURE(type);
665635

666-
// mini_jit::TensorOperation tensor_op;
667-
// TensorOperation::error_t err = tensor_op.setup(
668-
// TensorOperation::dtype_t::fp32, TensorOperation::prim_t::none, type, TensorOperation::prim_t::none, std::span{dim_types},
669-
// std::span{exec_types}, std::span{dim_sizes}, std::span{strides_in0}, std::span{strides_in1}, std::span{strides_out});
636+
constexpr TensorOperation::dim_t dim_types[]{TensorOperation::dim_t::n, TensorOperation::dim_t::k, TensorOperation::dim_t::c,
637+
TensorOperation::dim_t::m, TensorOperation::dim_t::k, TensorOperation::dim_t::m,
638+
TensorOperation::dim_t::m, TensorOperation::dim_t::n};
639+
constexpr TensorOperation::exec_t exec_types[]{TensorOperation::exec_t::seq, TensorOperation::exec_t::seq, TensorOperation::exec_t::seq,
640+
TensorOperation::exec_t::seq, TensorOperation::exec_t::seq, TensorOperation::exec_t::seq,
641+
TensorOperation::exec_t::prim, TensorOperation::exec_t::prim};
642+
constexpr int64_t dim_sizes[]{2, 3, 5, 8, 13, 21, 16, 16};
643+
constexpr int64_t strides_in0[]{16 * 16 * 1 * 13 * 8 * 1 * 3,
644+
0, // k-dim
645+
16 * 16 * 1 * 13 * 8,
646+
16 * 16 * 1 * 13,
647+
0, // k-dim
648+
16 * 16,
649+
1,
650+
16};
651+
constexpr int64_t strides_in1[]{0, 0, 0, 0, 0, 0, 0, 0};
652+
constexpr int64_t strides_out[]{16 * 16 * 1 * 13 * 8 * 1 * 3,
653+
0, // k-dim
654+
16 * 16 * 1 * 13 * 8,
655+
16 * 16 * 1 * 13,
656+
0, // k-dim
657+
16 * 16,
658+
1,
659+
16};
670660

671-
// REQUIRE(err == TensorOperation::error_t::success);
661+
GenerationTest test(16, 16, 16, 1, 16 * 16 * 21 * 13 * 8 * 5 * 3 * 2, 0, 16 * 16 * 21 * 13 * 8 * 5 * 3 * 2);
662+
test.SetUp(TestInfill::Random);
672663

673-
// tensor_op.execute(test.matrix_a.data(), nullptr, test.matrix_c.data());
664+
mini_jit::TensorOperation tensor_op;
665+
TensorOperation::error_t err = tensor_op.setup(
666+
TensorOperation::dtype_t::fp32, TensorOperation::prim_t::none, type, TensorOperation::prim_t::none, std::span{dim_types},
667+
std::span{exec_types}, std::span{dim_sizes}, std::span{strides_in0}, std::span{strides_in1}, std::span{strides_out});
674668

675-
// UnaryType test_type = UnaryType::None;
676-
// switch (type)
677-
// {
678-
// case TensorOperation::prim_t::zero:
679-
// test_type = UnaryType::Zero;
680-
// break;
681-
// case TensorOperation::prim_t::copy:
682-
// test_type = UnaryType::Identity;
683-
// break;
684-
// case TensorOperation::prim_t::relu:
685-
// test_type = UnaryType::ReLu;
686-
// break;
687-
// default:
688-
// FAIL("Could not parse the unary type!");
689-
// break;
690-
// }
669+
REQUIRE(err == TensorOperation::error_t::success);
691670

692-
// for (size_t i0 = 0; i0 < dim_sizes[0]; i0++)
693-
// {
694-
// for (size_t i1 = 0; i1 < dim_sizes[1]; i1++)
695-
// {
696-
// for (size_t i2 = 0; i2 < dim_sizes[2]; i2++)
697-
// {
698-
// for (size_t i3 = 0; i3 < dim_sizes[3]; i3++)
699-
// {
700-
// for (size_t i4 = 0; i4 < dim_sizes[4]; i4++)
701-
// {
702-
// for (size_t i5 = 0; i5 < dim_sizes[5]; i5++)
703-
// {
704-
// uint64_t offset_a = i0 * strides_in0[0] + i1 * strides_in0[1] + i2 * strides_in0[2] + i3 * strides_in0[3] +
705-
// i4 * strides_in0[4] + i5 * strides_in0[5];
706-
// uint64_t offset_c = i0 * strides_out[0] + i1 * strides_out[1] + i2 * strides_out[2] + i3 * strides_out[3] +
707-
// i4 * strides_out[4] + i5 * strides_out[5];
708-
// test.naive_unary_M_N(test.matrix_a.data() + offset_a, test.matrix_c_verify.data() + offset_c, 64, 64, false, test_type);
709-
// }
710-
// }
711-
// }
712-
// }
713-
// }
714-
// }
671+
tensor_op.execute(test.matrix_a.data(), nullptr, test.matrix_c.data());
715672

716-
// test.verify_matmul(test.matrix_c_verify.data(), test.matrix_c.data(), test.matrix_c.size());
717-
// }
673+
UnaryType test_type = UnaryType::None;
674+
switch (type)
675+
{
676+
case TensorOperation::prim_t::zero:
677+
test_type = UnaryType::Zero;
678+
break;
679+
case TensorOperation::prim_t::copy:
680+
test_type = UnaryType::Identity;
681+
break;
682+
case TensorOperation::prim_t::relu:
683+
test_type = UnaryType::ReLu;
684+
break;
685+
default:
686+
FAIL("Could not parse the unary type!");
687+
break;
688+
}
689+
690+
for (size_t i0 = 0; i0 < dim_sizes[0]; i0++)
691+
{
692+
for (size_t i1 = 0; i1 < dim_sizes[1]; i1++)
693+
{
694+
for (size_t i2 = 0; i2 < dim_sizes[2]; i2++)
695+
{
696+
for (size_t i3 = 0; i3 < dim_sizes[3]; i3++)
697+
{
698+
for (size_t i4 = 0; i4 < dim_sizes[4]; i4++)
699+
{
700+
for (size_t i5 = 0; i5 < dim_sizes[5]; i5++)
701+
{
702+
uint64_t offset_a = i0 * strides_in0[0] + i1 * strides_in0[1] + i2 * strides_in0[2] + i3 * strides_in0[3] +
703+
i4 * strides_in0[4] + i5 * strides_in0[5];
704+
uint64_t offset_c = i0 * strides_out[0] + i1 * strides_out[1] + i2 * strides_out[2] + i3 * strides_out[3] +
705+
i4 * strides_out[4] + i5 * strides_out[5];
706+
test.naive_unary_M_N(test.matrix_a.data() + offset_a, test.matrix_c_verify.data() + offset_c, 16, 16, false, test_type);
707+
}
708+
}
709+
}
710+
}
711+
}
712+
}
713+
714+
test.verify_matmul(test.matrix_c_verify.data(), test.matrix_c.data(), test.matrix_c.size());
715+
}
718716

719717
TEST_CASE("Test tensor operation with outer loop with main kernel: gemm", "[tensor_operation][gemm][correctness]")
720718
{

0 commit comments

Comments
 (0)