diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 2c8740f4f0ee..c10931d1bd10 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -248,20 +248,25 @@ class UndefinedVarVerifier : public Verifier { bool redefine_is_allowed = redefine_allowed_within_function_.count(var); { auto it = currently_defined_.find(var); - Verify(it == currently_defined_.end() || redefine_is_allowed) - << "ValueError: " - << "TIR is ill-formed, " - << "due to multiple nested definitions of variable " << var - << ". It was first defined at " << it->second << ", and was re-defined at " << path; + auto verify = Verify(it == currently_defined_.end() || redefine_is_allowed); + verify << "ValueError: " + << "TIR is ill-formed, " + << "due to multiple nested definitions of variable " << var << "."; + if (it != currently_defined_.end()) { + verify << " It was first defined at " << it->second << ", and was re-defined at " << path; + } } { auto it = previously_defined_.find(var); - Verify(it == previously_defined_.end() || redefine_is_allowed) - << "ValueError: " - << "TIR is ill-formed, " - << "due to multiple definitions of variable " << var << ". It was first defined at " - << it->second << ", and was later re-defined at " << path; + auto verify = Verify(it == previously_defined_.end() || redefine_is_allowed); + verify << "ValueError: " + << "TIR is ill-formed, " + << "due to multiple definitions of variable " << var << "."; + if (it != previously_defined_.end()) { + verify << " It was first defined at " << it->second << ", and was later re-defined at " + << path; + } } currently_defined_.insert({var, path}); diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index cddc9131f30f..f6e1d2eade24 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -345,5 +345,88 @@ def func(): tvm.tir.analysis.verify_well_formed(mod) +def test_error_message_without_previous_definition_location(): + """Test case 1: Error message without 'It was first defined at' + + This tests the scenario where it == end(), so the error message should contain + 'TIR is ill-formed, due to multiple definitions of variable' but should NOT + contain 'It was first defined at' since the iterator is invalid. + """ + + @T.prim_func(check_well_formed=False) + def func(): + x = T.int32() + + with T.LetStmt(42, var=x): + T.evaluate(x) + + with T.LetStmt(99, var=x): # This should trigger the error + T.evaluate(x) + + with pytest.raises(ValueError) as exc_info: + tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + + error_msg = str(exc_info.value) + + assert "TIR is ill-formed" in error_msg + assert "multiple definitions of variable" in error_msg + + +def test_error_message_with_previous_definition_location(): + """Test case 2: Error message with 'It was first defined at' + + This tests the scenario where it != end(), so the error message should contain + both 'TIR is ill-formed, due to multiple definitions of variable' and should also + contain 'It was first defined at' with the location information. + """ + + @T.prim_func(check_well_formed=False) + def func(): + x = T.int32() + + with T.LetStmt(42, var=x): + with T.LetStmt(99, var=x): # This should trigger the error + T.evaluate(x) + + with pytest.raises(ValueError) as exc_info: + tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + + error_msg = str(exc_info.value) + + assert "TIR is ill-formed" in error_msg + assert "multiple nested definitions of variable" in error_msg + + # should contains location information since it != end() + assert "It was first defined at" in error_msg + assert "was re-defined at" in error_msg + + +def test_sequential_redefinition_with_location(): + """Test case 2b: Sequential redefinition that includes location info + + This tests the previously_defined_ path where it != end() + """ + + @T.prim_func(check_well_formed=False) + def func(): + x = T.int32() + + with T.LetStmt(1, var=x): + T.evaluate(x) + + with T.LetStmt(2, var=x): # This should trigger the error + T.evaluate(x) + + with pytest.raises(ValueError) as exc_info: + tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + + error_msg = str(exc_info.value) + + assert "TIR is ill-formed" in error_msg + assert "multiple definitions of variable" in error_msg + assert "It was first defined at" in error_msg + assert "later re-defined at" in error_msg + + if __name__ == "__main__": tvm.testing.main()