Skip to content

Commit 5ee55df

Browse files
committed
fix: tranpose stride issue
1 parent 20401dc commit 5ee55df

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
# Execute tests defined by the CMake configuration.
4646
run: |
4747
ctest -j ${{env.parallel_processes}} -C ${{matrix.build_type}} --test-dir neon --output-on-failure
48-
ctest -j ${{env.parallel_processes}} -C ${{matrix.build_type}} --output-on-failure -E "^Test einsum tree optimize and execute first example"
48+
ctest -j ${{env.parallel_processes}} -C ${{matrix.build_type}} --output-on-failure
4949
5050
- name: Test + Valgrind
5151
working-directory: ${{github.workspace}}/build

src/main/EinsumTree.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,14 +304,17 @@ mini_jit::TensorConfig mini_jit::EinsumTree::lower_node(const EinsumNode *node)
304304
"Expected input and output to have same dimensions for copy operation.");
305305
release_assert(node->get_size(dim_sizes) == node->left->get_size(dim_sizes), "Expected the accumulated size to be the same.");
306306

307+
std::vector<int64_t> stridesIn0 = compute_strides(node->left->output_dim_ids);
308+
stridesIn0 = swap_strides_id_based(stridesIn0, node->left->output_dim_ids, node->output_dim_ids);
309+
307310
TensorConfig config{
308311
TensorConfig::prim_t::none, // first_touch
309312
TensorConfig::prim_t::copy, // main
310313
TensorConfig::prim_t::none, // last touch
311314
std::vector<TensorConfig::dim_t>(node->output_dim_ids.size(), TensorConfig::dim_t::c), // dim_types
312315
std::vector<TensorConfig::exec_t>(node->output_dim_ids.size(), TensorConfig::exec_t::seq), // exec_types
313316
get_output_dims(node->output_dim_ids), // dim_sizes
314-
compute_strides(node->left->output_dim_ids), // strides_in0
317+
stridesIn0, // strides_in0
315318
std::vector<int64_t>(node->output_dim_ids.size(), 0), // strides_in1 (not used for transposition)
316319
compute_strides(node->output_dim_ids), // strides_out
317320
TensorConfig::dtype_t::fp32 // dtype_t
@@ -506,7 +509,9 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
506509
node->tensor = new float[node->get_size(dim_sizes)]();
507510
}
508511

512+
std::cerr << node->to_string() << node->tensor_op.get_config().to_string() << std::endl;
509513
node->tensor_op.execute(node->left->tensor, nullptr, node->tensor);
514+
std::cerr << "SUCCESS" << std::endl << std::endl;
510515
}
511516
else if (node->type == NodeType::Contraction)
512517
{
@@ -570,6 +575,26 @@ std::vector<int64_t> mini_jit::EinsumTree::compute_strides(const std::vector<int
570575
return strides;
571576
}
572577

578+
std::vector<int64_t> mini_jit::EinsumTree::swap_strides_id_based(const std::vector<int64_t> &strides, const std::vector<int64_t> &inIds,
579+
const std::vector<int64_t> &outIds)
580+
{
581+
release_assert(inIds.size() == outIds.size(), "Expected inIds to have the same size as the outIds.");
582+
release_assert(inIds.size() == strides.size(), "Expected the inIds to have the same size as the outIds.");
583+
584+
std::vector<int64_t> outStrides(strides.size());
585+
586+
for (size_t i = 0; i < inIds.size(); i++)
587+
{
588+
auto outPtr = std::find(outIds.begin(), outIds.end(), inIds[i]);
589+
release_assert(outPtr != outIds.end(), "Expected to have the same elements as the inIds.");
590+
591+
auto outIndex = std::distance(outIds.begin(), outPtr);
592+
outStrides[outIndex] = strides[i];
593+
}
594+
595+
return outStrides;
596+
}
597+
573598
std::vector<int64_t> mini_jit::EinsumTree::get_output_dims(const std::vector<int64_t> &dim_ids)
574599
{
575600
std::vector<int64_t> dims(dim_ids.size());

src/main/EinsumTree.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ namespace mini_jit
240240
*/
241241
ErrorParse generate_operators();
242242

243+
/**
244+
* @brief Swap the strides so that the strides position match the out Ids with the current stride location based on the inIds.
245+
*
246+
* @param strides The strides to adjust.
247+
* @param inIds The ids the strides got calculated with.
248+
* @param outIds The order of the strides that is expected for the strides.
249+
*/
250+
std::vector<int64_t> swap_strides_id_based(const std::vector<int64_t> &strides, const std::vector<int64_t> &inIds,
251+
const std::vector<int64_t> &outIds);
252+
243253
public:
244254
EinsumTree(const std::string &tree_str);
245255
EinsumTree(const std::string &tree_str, const std::vector<int64_t> &sorted_dim_sizes);

0 commit comments

Comments
 (0)