Skip to content

Commit 0cfeeb6

Browse files
Maalvi14awni
andauthored
Einsum error msg improvement (#2690)
* Improved error message for Einsum * Modifications via pre-commit * format * nits --------- Co-authored-by: Awni Hannun <[email protected]>
1 parent 8f8af61 commit 0cfeeb6

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

mlx/einsum.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,8 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
671671
}
672672
int max_ellipsis_length = 0;
673673
auto check_letters_and_expand_ellipsis = [&](auto& subscript,
674-
const array* operand) {
674+
const array* operand,
675+
int operand_idx) {
675676
bool have_ellipsis = false;
676677
int cnt_before = 0, cnt_after = 0;
677678
for (int i = 0; i < subscript.size(); i++) {
@@ -708,10 +709,21 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
708709
int ellipsis_length;
709710
if (operand != nullptr) {
710711
ellipsis_length = operand->ndim() - cnt_before - cnt_after;
712+
if (ellipsis_length < 0) {
713+
std::ostringstream msg;
714+
msg << "[" << fn_name << "] Operand " << operand_idx << " with shape "
715+
<< operand->shape()
716+
<< " has insufficient dimensions for subscript '" << subscript
717+
<< "'. The ellipsis requires at least "
718+
<< (cnt_before + cnt_after) << " dimensions but the operand has "
719+
<< operand->ndim() << " dimensions.";
720+
throw std::invalid_argument(msg.str());
721+
}
711722
max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
712723
} else {
713724
ellipsis_length = max_ellipsis_length;
714725
}
726+
715727
subscript.replace(
716728
subscript.begin() + cnt_before,
717729
subscript.begin() + cnt_before + 3,
@@ -721,9 +733,9 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
721733
};
722734

723735
for (int i = 0; i < operands.size(); i++) {
724-
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]);
736+
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i], i);
725737
}
726-
check_letters_and_expand_ellipsis(out_subscript, nullptr);
738+
check_letters_and_expand_ellipsis(out_subscript, nullptr, -1);
727739

728740
CharSet out_set(out_subscript.begin(), out_subscript.end());
729741
if (out_set.size() != out_subscript.size()) {

0 commit comments

Comments
 (0)