Skip to content

Commit 9bbd375

Browse files
MillaFleursKD2YCUangeloskath
authored
Fix return value in einsum_path for simple contractions (#3232)
Co-authored-by: KD2YCU <me@kd2ycu.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
1 parent a25399c commit 9bbd375

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

mlx/einsum.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
800800
max_size = std::max(max_size, term_size(in, dim_map));
801801
}
802802

803-
PathInfo path_info;
803+
PathInfo path_info{};
804804

805805
// Get the full naive cost
806806
std::tie(path_info.naive_cost, path_info.naive_scaling) =
@@ -813,6 +813,8 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
813813
std::iota(positions.begin(), positions.end(), 0);
814814
path.emplace_back(
815815
std::move(inputs), std::move(output), std::move(positions));
816+
path_info.optimized_cost = path_info.naive_cost;
817+
path_info.optimized_scaling = path_info.naive_scaling;
816818
} else {
817819
std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
818820
greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);

0 commit comments

Comments
 (0)