@@ -304,14 +304,17 @@ mini_jit::TensorConfig mini_jit::EinsumTree::lower_node(const EinsumNode *node)
304304 " Expected input and output to have same dimensions for copy operation." );
305305 release_assert (node->get_size (dim_sizes) == node->left ->get_size (dim_sizes), " Expected the accumulated size to be the same." );
306306
307+ std::vector<int64_t > stridesIn0 = compute_strides (node->left ->output_dim_ids );
308+ stridesIn0 = swap_strides_id_based (stridesIn0, node->left ->output_dim_ids , node->output_dim_ids );
309+
307310 TensorConfig config{
308311 TensorConfig::prim_t ::none, // first_touch
309312 TensorConfig::prim_t ::copy, // main
310313 TensorConfig::prim_t ::none, // last touch
311314 std::vector<TensorConfig::dim_t >(node->output_dim_ids .size (), TensorConfig::dim_t ::c), // dim_types
312315 std::vector<TensorConfig::exec_t >(node->output_dim_ids .size (), TensorConfig::exec_t ::seq), // exec_types
313316 get_output_dims (node->output_dim_ids ), // dim_sizes
314- compute_strides (node-> left -> output_dim_ids ), // strides_in0
317+ stridesIn0, // strides_in0
315318 std::vector<int64_t >(node->output_dim_ids .size (), 0 ), // strides_in1 (not used for transposition)
316319 compute_strides (node->output_dim_ids ), // strides_out
317320 TensorConfig::dtype_t ::fp32 // dtype_t
@@ -506,7 +509,9 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
506509 node->tensor = new float [node->get_size (dim_sizes)]();
507510 }
508511
512+ std::cerr << node->to_string () << node->tensor_op .get_config ().to_string () << std::endl;
509513 node->tensor_op .execute (node->left ->tensor , nullptr , node->tensor );
514+ std::cerr << " SUCCESS" << std::endl << std::endl;
510515 }
511516 else if (node->type == NodeType::Contraction)
512517 {
@@ -570,6 +575,26 @@ std::vector<int64_t> mini_jit::EinsumTree::compute_strides(const std::vector<int
570575 return strides;
571576}
572577
578+ std::vector<int64_t > mini_jit::EinsumTree::swap_strides_id_based (const std::vector<int64_t > &strides, const std::vector<int64_t > &inIds,
579+ const std::vector<int64_t > &outIds)
580+ {
581+ release_assert (inIds.size () == outIds.size (), " Expected inIds to have the same size as the outIds." );
582+ release_assert (inIds.size () == strides.size (), " Expected the inIds to have the same size as the outIds." );
583+
584+ std::vector<int64_t > outStrides (strides.size ());
585+
586+ for (size_t i = 0 ; i < inIds.size (); i++)
587+ {
588+ auto outPtr = std::find (outIds.begin (), outIds.end (), inIds[i]);
589+ release_assert (outPtr != outIds.end (), " Expected to have the same elements as the inIds." );
590+
591+ auto outIndex = std::distance (outIds.begin (), outPtr);
592+ outStrides[outIndex] = strides[i];
593+ }
594+
595+ return outStrides;
596+ }
597+
573598std::vector<int64_t > mini_jit::EinsumTree::get_output_dims (const std::vector<int64_t > &dim_ids)
574599{
575600 std::vector<int64_t > dims (dim_ids.size ());
0 commit comments