Skip to content

Commit 6bb5c04

Browse files
committed
fix: interface
1 parent 5ee55df commit 6bb5c04

File tree

9 files changed

+131
-60
lines changed

9 files changed

+131
-60
lines changed

src/interface/Contraction.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mlc::Error mlc::contraction(const Tensor &input0, const Tensor &input1, Tensor &
1313
const UnaryType firstTouch, const UnaryType lastTouch)
1414
{
1515
mini_jit::EinsumTree einsumTree(contraction);
16-
mini_jit::EinsumTree::ErrorParse errorParse = einsumTree.parse_tree();
16+
mini_jit::EinsumTree::ErrorParse errorParse = einsumTree.parse_tree(false);
1717
if (errorParse != mini_jit::EinsumTree::ErrorParse::None)
1818
{
1919
mlc::ErrorType type = internal::convertParseError(errorParse);
@@ -28,6 +28,12 @@ mlc::Error mlc::contraction(const Tensor &input0, const Tensor &input1, Tensor &
2828
std::vector<int64_t> sorted_dim_sizes;
2929
internal::get_sorted_dimensions_sizes(einsumTree.get_root(), {input0, input1}, sorted_dim_sizes);
3030
einsumTree.set_sorted_dim_sizes(sorted_dim_sizes);
31+
errorParse = einsumTree.generate_operators();
32+
if (errorParse != mini_jit::EinsumTree::ErrorParse::None)
33+
{
34+
mlc::ErrorType type = internal::convertParseError(errorParse);
35+
return {type, "Failed during operator generation for the given einsum tree."};
36+
}
3137

3238
mini_jit::TensorOperation op;
3339
mini_jit::TensorConfig config = einsumTree.lower_node(einsumTree.get_root());

src/interface/Einsum.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ mlc::Error mlc::einsum(const std::vector<Tensor *> &inputs, Tensor &output, cons
1616
mlc::EinsumOperation::EinsumOperation(const std::vector<std::reference_wrapper<const Tensor>> &inputs, Tensor &, const std::string &tree)
1717
: einsumTree(tree)
1818
{
19-
mini_jit::EinsumTree::ErrorParse errorParse = einsumTree.parse_tree();
19+
mini_jit::EinsumTree::ErrorParse errorParse = einsumTree.parse_tree(false);
2020
if (errorParse != mini_jit::EinsumTree::ErrorParse::None)
2121
{
2222
mlc::ErrorType type = internal::convertParseError(errorParse);
@@ -26,6 +26,12 @@ mlc::EinsumOperation::EinsumOperation(const std::vector<std::reference_wrapper<c
2626
std::vector<int64_t> sorted_dim_sizes;
2727
internal::get_sorted_dimensions_sizes<std::reference_wrapper<const Tensor>>(einsumTree.get_root(), inputs, sorted_dim_sizes);
2828
einsumTree.set_sorted_dim_sizes(sorted_dim_sizes);
29+
errorParse = einsumTree.generate_operators();
30+
if (errorParse != mini_jit::EinsumTree::ErrorParse::None)
31+
{
32+
mlc::ErrorType type = internal::convertParseError(errorParse);
33+
error = {type, "Failed to generate operators for the tree."};
34+
}
2935

3036
error = {mlc::ErrorType::None, "Success"};
3137
}

src/interface/Einsum.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace mlc
2222
template <typename T> mlc::Error einsum(const std::vector<T> &inputs, mlc::Tensor &output, const std::string &tree)
2323
{
2424
mini_jit::EinsumTree einsumTree(tree);
25-
mini_jit::EinsumTree::ErrorParse errorParse = einsumTree.parse_tree();
25+
mini_jit::EinsumTree::ErrorParse errorParse = einsumTree.parse_tree(false);
2626
if (errorParse != mini_jit::EinsumTree::ErrorParse::None)
2727
{
2828
mlc::ErrorType type = convertParseError(errorParse);
@@ -32,6 +32,12 @@ namespace mlc
3232
std::vector<int64_t> sorted_dim_sizes;
3333
get_sorted_dimensions_sizes(einsumTree.get_root(), inputs, sorted_dim_sizes);
3434
einsumTree.set_sorted_dim_sizes(sorted_dim_sizes);
35+
errorParse = einsumTree.generate_operators();
36+
if (errorParse != mini_jit::EinsumTree::ErrorParse::None)
37+
{
38+
mlc::ErrorType type = convertParseError(errorParse);
39+
return {type, "Failed during operator generation for the given einsum tree."};
40+
}
3541

3642
std::vector<void *> tensors(inputs.size() + 1);
3743
for (size_t i = 0; i < inputs.size(); i++)

src/interface/TensorUtils.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ namespace mlc
145145
*/
146146
constexpr mlc::ErrorType convertParseError(mini_jit::EinsumTree::ErrorParse error)
147147
{
148+
if (static_cast<int64_t>(error) > 100)
149+
{
150+
return static_cast<mlc::ErrorType>(static_cast<int64_t>(error));
151+
}
152+
148153
switch (error)
149154
{
150155
case mini_jit::EinsumTree::ErrorParse::None:
@@ -176,11 +181,6 @@ namespace mlc
176181
*/
177182
constexpr mlc::ErrorType convertErrorExecute(mini_jit::EinsumTree::ErrorExecute error)
178183
{
179-
if (static_cast<int64_t>(error) > 100)
180-
{
181-
return static_cast<mlc::ErrorType>(static_cast<int64_t>(error));
182-
}
183-
184184
switch (error)
185185
{
186186
case mini_jit::EinsumTree::ErrorExecute::None:

src/main/EinsumTree.cpp

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ mini_jit::TensorConfig mini_jit::EinsumTree::lower_node(const EinsumNode *node)
275275
std::vector<int64_t> dim_sizes;
276276
std::map<int64_t, size_t> id_map;
277277
uint32_t number_of_k = 0;
278+
dim_types.reserve(node->output_dim_ids.size());
279+
dim_sizes.reserve(node->output_dim_ids.size());
278280

279281
get_config_dim_types_and_sizes(node, id_map, dim_types, dim_sizes, number_of_k);
280282

@@ -509,9 +511,12 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
509511
node->tensor = new float[node->get_size(dim_sizes)]();
510512
}
511513

512-
std::cerr << node->to_string() << node->tensor_op.get_config().to_string() << std::endl;
514+
if (node->tensor_op.getHasSetupError() == true)
515+
{
516+
return ErrorExecute::SetupHasError;
517+
}
518+
513519
node->tensor_op.execute(node->left->tensor, nullptr, node->tensor);
514-
std::cerr << "SUCCESS" << std::endl << std::endl;
515520
}
516521
else if (node->type == NodeType::Contraction)
517522
{
@@ -541,6 +546,11 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
541546
node->tensor = new float[node->get_size(dim_sizes)]();
542547
}
543548

549+
if (node->tensor_op.getHasSetupError() == true)
550+
{
551+
return ErrorExecute::SetupHasError;
552+
}
553+
544554
node->tensor_op.execute(node->left->tensor, node->right->tensor, node->tensor);
545555
}
546556
else
@@ -680,7 +690,7 @@ void mini_jit::EinsumTree::conditional_swap(mini_jit::EinsumTree::EinsumNode *no
680690
}
681691
}
682692

683-
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree()
693+
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree(bool build_operators)
684694
{
685695
ErrorParse error = parse_tree_no_optimization(false);
686696

@@ -691,7 +701,10 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree()
691701

692702
optimize(root);
693703

694-
error = generate_operators();
704+
if (build_operators)
705+
{
706+
error = generate_operators();
707+
}
695708

696709
return error;
697710
}
@@ -704,63 +717,73 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::generate_operators()
704717
return ErrorParse::InvalidRoot;
705718
}
706719

707-
ErrorParse error = ErrorParse::None;
708-
std::vector<EinsumNode *> stack = {root};
720+
return generate_operator_node(root);
721+
}
709722

710-
while (stack.size() > 0)
723+
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::generate_operator_node(EinsumNode *node)
724+
{
725+
if (node->type == NodeType::Leaf)
726+
{
727+
return ErrorParse::None;
728+
}
729+
else if (node->type == NodeType::Transposition)
711730
{
712-
EinsumNode *node = stack.back();
713-
stack.pop_back();
731+
release_assert(node->left != nullptr, "Expected the left child of contraction to be a valid pointer.");
732+
release_assert(node->right == nullptr, "Expected the right child of contraction to be a nullptr.");
733+
734+
ErrorParse error = generate_operator_node(node->left);
714735

715-
if (node->type == NodeType::Leaf)
736+
if (error != ErrorParse::None)
716737
{
717-
continue;
738+
return error;
718739
}
719-
else if (node->type == NodeType::Transposition)
720-
{
721-
release_assert(node->left != nullptr, "Expected the left child of contraction to be a valid pointer.");
722-
release_assert(node->right == nullptr, "Expected the right child of contraction to be a nullptr.");
723-
724-
stack.push_back(node->left);
725740

726-
TensorConfig config = lower_node(node);
727-
TensorOperation::error_t error_setup = node->tensor_op.setup(config);
728-
error = parse_setup_error(error_setup);
741+
TensorConfig config = lower_node(node);
742+
TensorOperation::error_t error_setup = node->tensor_op.setup(config);
743+
error = parse_setup_error(error_setup);
729744

730-
if (error != ErrorParse::None)
731-
{
732-
return error;
733-
}
745+
if (error != ErrorParse::None)
746+
{
747+
return error;
748+
}
734749

735750
#ifdef SAVE_JITS_TO_FILE
736-
node->tensor_op.write_kernel_to_file(node->name());
751+
node->tensor_op.write_kernel_to_file(node->name());
737752
#endif // SAVE_JITS_TO_FILE
738-
}
739-
else if (node->type == NodeType::Contraction)
753+
}
754+
else if (node->type == NodeType::Contraction)
755+
{
756+
release_assert(node->left != nullptr, "Expected the left child of contraction to be a valid pointer.");
757+
release_assert(node->right != nullptr, "Expected the right child of contraction to be a valid pointer.");
758+
759+
ErrorParse error = generate_operator_node(node->left);
760+
if (error != ErrorParse::None)
740761
{
741-
release_assert(node->left != nullptr, "Expected the left child of contraction to be a valid pointer.");
742-
release_assert(node->right != nullptr, "Expected the right child of contraction to be a valid pointer.");
762+
return error;
763+
}
743764

744-
stack.push_back(node->left);
745-
stack.push_back(node->right);
765+
error = generate_operator_node(node->right);
766+
if (error != ErrorParse::None)
767+
{
768+
return error;
769+
}
746770

747-
TensorConfig config = lower_node(node);
748-
TensorOperation::error_t error_setup = node->tensor_op.setup(config);
749-
error = parse_setup_error(error_setup);
771+
TensorConfig config = lower_node(node);
772+
TensorOperation::error_t error_setup = node->tensor_op.setup(config);
773+
error = parse_setup_error(error_setup);
750774

751-
if (error != ErrorParse::None)
752-
{
753-
return error;
754-
}
775+
if (error != ErrorParse::None)
776+
{
777+
return error;
778+
}
755779

756780
#ifdef SAVE_JITS_TO_FILE
757-
node->tensor_op.write_kernel_to_file(node->name());
781+
node->tensor_op.write_kernel_to_file(node->name());
758782
#endif // SAVE_JITS_TO_FILE
759-
}
760-
else
761-
{
762-
release_assert(false, "Found unhandled einsum tree node type.");
763-
}
783+
}
784+
else
785+
{
786+
release_assert(false, "Found unhandled einsum tree node type.");
764787
}
765788
return ErrorParse::None;
766789
}

src/main/EinsumTree.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace mini_jit
4949
InvalidRoot = 1,
5050
NotEnoughInputTensors = 2,
5151
TooManyInputTensors = 3,
52+
SetupHasError = 4,
5253
NullPtrAsInputTensor = 5,
5354
};
5455

@@ -234,11 +235,12 @@ namespace mini_jit
234235
int32_t findMDim(EinsumNode *Node);
235236

236237
/**
237-
* @brief Generates the operator to the parsed einsum tree.
238+
* @brief Generates the operator of the given node recursively
238239
*
239-
* @return ErrorParse indicating the result of the parsing operation.
240+
* @param node The node to generate the operator for.
241+
* @return ErrorParse The error during creation of the operator.
240242
*/
241-
ErrorParse generate_operators();
243+
ErrorParse generate_operator_node(EinsumNode *node);
242244

243245
/**
244246
* @brief Swap the strides so that the strides position match the out Ids with the current stride location based on the inIds.
@@ -275,9 +277,17 @@ namespace mini_jit
275277
/**
276278
* Parses the einsum tree string, builds the tree structure and optimizes the tree.
277279
*
280+
* @param build_operators indicates if the operators should be generate with the parse.
278281
* @return ErrorParse indicating the result of the parsing operation.
279282
*/
280-
ErrorParse parse_tree();
283+
ErrorParse parse_tree(bool build_operators = true);
284+
285+
/**
286+
* @brief Generates the operator to the parsed einsum tree.
287+
*
288+
* @return ErrorParse indicating the result of the parsing operation.
289+
*/
290+
ErrorParse generate_operators();
281291

282292
/**
283293
* Returns the root node of the EinsumTree.

src/main/TensorOperation.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
403403
std::span<const int64_t> strides_in0, std::span<const int64_t> strides_in1, std::span<const int64_t> strides_out)
404404
{
405405
// Reset to defaults
406-
hasSetupError = false;
406+
hasSetupError = true;
407407
isParallel = false;
408408
isTranspose = false;
409409
indexPrimBatch = -1;
@@ -665,6 +665,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
665665
TensorOperation::strides_in1 = strides_in1;
666666
TensorOperation::strides_out = strides_out;
667667

668+
hasSetupError = false;
668669
return error_t::success;
669670
}
670671

@@ -833,4 +834,9 @@ void mini_jit::TensorOperation::write_kernel_to_file(std::string path_no_extensi
833834
{
834835
std::get<Unary>(last_touch).write_kernel_to_file(std::format("{}_first_touch.bin", path_no_extension).c_str());
835836
}
836-
}
837+
}
838+
839+
bool mini_jit::TensorOperation::getHasSetupError()
840+
{
841+
return hasSetupError;
842+
}

src/main/TensorOperation.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ namespace mini_jit
7373

7474
bool isTranspose = false; // default is no transpose
7575

76-
bool hasSetupError = false;
76+
bool hasSetupError = true; // default is true to indicate no setup was executed
7777

7878
/**
7979
* @brief Validates that exactly one m primitive dimension and one n primitive dimension exists.
@@ -249,6 +249,14 @@ namespace mini_jit
249249
* @param path The file to write the kernel to without extension.
250250
*/
251251
void write_kernel_to_file(std::string path_no_extension) const;
252+
253+
/**
254+
* @brief Indicates if the setup resulted in an error or was not initalized.
255+
*
256+
* @return true Setup has error.
257+
* @return false The setup was successful.
258+
*/
259+
bool getHasSetupError();
252260
};
253261
}; // namespace mini_jit
254262

src/test/interface/TensorUtils.test.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ TEST_CASE("Test interface tensor utils get_sorted_dimensions_sizes", "[tensor][c
3737
mlc::Tensor tensor3(data2, shape3);
3838

3939
mini_jit::EinsumTree tree("[0,1],[1,2]->[0,2]");
40-
tree.parse_tree();
40+
mini_jit::EinsumTree::ErrorParse error = tree.parse_tree(false);
41+
42+
REQUIRE(error == mini_jit::EinsumTree::ErrorParse::None);
4143

4244
std::vector<int64_t> sorted_dimensions_sizes;
4345
mlc::internal::get_sorted_dimensions_sizes(tree.get_root(), {tensor1, tensor2}, sorted_dimensions_sizes);
@@ -50,4 +52,8 @@ TEST_CASE("Test interface tensor utils get_sorted_dimensions_sizes", "[tensor][c
5052
CAPTURE(i);
5153
REQUIRE(expected[i] == sorted_dimensions_sizes[i]);
5254
}
55+
56+
delete[] data1;
57+
delete[] data2;
58+
delete[] data3;
5359
}

0 commit comments

Comments
 (0)