@@ -275,6 +275,8 @@ mini_jit::TensorConfig mini_jit::EinsumTree::lower_node(const EinsumNode *node)
275275 std::vector<int64_t > dim_sizes;
276276 std::map<int64_t , size_t > id_map;
277277 uint32_t number_of_k = 0 ;
278+ dim_types.reserve (node->output_dim_ids .size ());
279+ dim_sizes.reserve (node->output_dim_ids .size ());
278280
279281 get_config_dim_types_and_sizes (node, id_map, dim_types, dim_sizes, number_of_k);
280282
@@ -509,9 +511,12 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
509511 node->tensor = new float [node->get_size (dim_sizes)]();
510512 }
511513
512- std::cerr << node->to_string () << node->tensor_op .get_config ().to_string () << std::endl;
514+ if (node->tensor_op .getHasSetupError () == true )
515+ {
516+ return ErrorExecute::SetupHasError;
517+ }
518+
513519 node->tensor_op .execute (node->left ->tensor , nullptr , node->tensor );
514- std::cerr << " SUCCESS" << std::endl << std::endl;
515520 }
516521 else if (node->type == NodeType::Contraction)
517522 {
@@ -541,6 +546,11 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
541546 node->tensor = new float [node->get_size (dim_sizes)]();
542547 }
543548
549+ if (node->tensor_op .getHasSetupError () == true )
550+ {
551+ return ErrorExecute::SetupHasError;
552+ }
553+
544554 node->tensor_op .execute (node->left ->tensor , node->right ->tensor , node->tensor );
545555 }
546556 else
@@ -680,7 +690,7 @@ void mini_jit::EinsumTree::conditional_swap(mini_jit::EinsumTree::EinsumNode *no
680690 }
681691}
682692
683- mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree ()
693+ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree (bool build_operators )
684694{
685695 ErrorParse error = parse_tree_no_optimization (false );
686696
@@ -691,7 +701,10 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree()
691701
692702 optimize (root);
693703
694- error = generate_operators ();
704+ if (build_operators)
705+ {
706+ error = generate_operators ();
707+ }
695708
696709 return error;
697710}
@@ -704,63 +717,73 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::generate_operators()
704717 return ErrorParse::InvalidRoot;
705718 }
706719
707- ErrorParse error = ErrorParse::None ;
708- std::vector<EinsumNode *> stack = {root};
720+ return generate_operator_node (root) ;
721+ }
709722
710- while (stack.size () > 0 )
723+ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::generate_operator_node (EinsumNode *node)
724+ {
725+ if (node->type == NodeType::Leaf)
726+ {
727+ return ErrorParse::None;
728+ }
729+ else if (node->type == NodeType::Transposition)
711730 {
712- EinsumNode *node = stack.back ();
713- stack.pop_back ();
731+ release_assert (node->left != nullptr , " Expected the left child of contraction to be a valid pointer." );
732+ release_assert (node->right == nullptr , " Expected the right child of contraction to be a nullptr." );
733+
734+ ErrorParse error = generate_operator_node (node->left );
714735
715- if (node-> type == NodeType::Leaf )
736+ if (error != ErrorParse::None )
716737 {
717- continue ;
738+ return error ;
718739 }
719- else if (node->type == NodeType::Transposition)
720- {
721- release_assert (node->left != nullptr , " Expected the left child of contraction to be a valid pointer." );
722- release_assert (node->right == nullptr , " Expected the right child of contraction to be a nullptr." );
723-
724- stack.push_back (node->left );
725740
726- TensorConfig config = lower_node (node);
727- TensorOperation::error_t error_setup = node->tensor_op .setup (config);
728- error = parse_setup_error (error_setup);
741+ TensorConfig config = lower_node (node);
742+ TensorOperation::error_t error_setup = node->tensor_op .setup (config);
743+ error = parse_setup_error (error_setup);
729744
730- if (error != ErrorParse::None)
731- {
732- return error;
733- }
745+ if (error != ErrorParse::None)
746+ {
747+ return error;
748+ }
734749
735750#ifdef SAVE_JITS_TO_FILE
736- node->tensor_op .write_kernel_to_file (node->name ());
751+ node->tensor_op .write_kernel_to_file (node->name ());
737752#endif // SAVE_JITS_TO_FILE
738- }
739- else if (node->type == NodeType::Contraction)
753+ }
754+ else if (node->type == NodeType::Contraction)
755+ {
756+ release_assert (node->left != nullptr , " Expected the left child of contraction to be a valid pointer." );
757+ release_assert (node->right != nullptr , " Expected the right child of contraction to be a valid pointer." );
758+
759+ ErrorParse error = generate_operator_node (node->left );
760+ if (error != ErrorParse::None)
740761 {
741- release_assert (node-> left != nullptr , " Expected the left child of contraction to be a valid pointer. " ) ;
742- release_assert (node-> right != nullptr , " Expected the right child of contraction to be a valid pointer. " );
762+ return error ;
763+ }
743764
744- stack.push_back (node->left );
745- stack.push_back (node->right );
765+ error = generate_operator_node (node->right );
766+ if (error != ErrorParse::None)
767+ {
768+ return error;
769+ }
746770
747- TensorConfig config = lower_node (node);
748- TensorOperation::error_t error_setup = node->tensor_op .setup (config);
749- error = parse_setup_error (error_setup);
771+ TensorConfig config = lower_node (node);
772+ TensorOperation::error_t error_setup = node->tensor_op .setup (config);
773+ error = parse_setup_error (error_setup);
750774
751- if (error != ErrorParse::None)
752- {
753- return error;
754- }
775+ if (error != ErrorParse::None)
776+ {
777+ return error;
778+ }
755779
756780#ifdef SAVE_JITS_TO_FILE
757- node->tensor_op .write_kernel_to_file (node->name ());
781+ node->tensor_op .write_kernel_to_file (node->name ());
758782#endif // SAVE_JITS_TO_FILE
759- }
760- else
761- {
762- release_assert (false , " Found unhandled einsum tree node type." );
763- }
783+ }
784+ else
785+ {
786+ release_assert (false , " Found unhandled einsum tree node type." );
764787 }
765788 return ErrorParse::None;
766789}
0 commit comments