-
Notifications
You must be signed in to change notification settings - Fork 2
check for Bottom in vmath rewrite #585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from kirin.prelude import basic | ||
| from kirin.dialects import vmath | ||
|
|
@@ -22,6 +23,7 @@ def add_scalar_lhs(): | |
| return add_kernel(x=3.0, y=[3.0, 4, 5]) | ||
|
|
||
|
|
||
| @pytest.mark.xfail() | ||
| def test_add_scalar_lhs(): | ||
| # out = add_scalar_lhs() | ||
| add_scalar_lhs.print() | ||
|
|
@@ -31,9 +33,10 @@ def test_add_scalar_lhs(): | |
| assert np.allclose(np.asarray(res), np.array([6, 7, 8])) | ||
|
|
||
|
|
||
| @pytest.mark.xfail() | ||
|
||
| def test_typed_kernel_add(): | ||
| add_scalar_rhs_typed.print() | ||
| res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1) | ||
| res = add_scalar_rhs_typed(IList([0.0, 1.0, 2.0]), 3.1) | ||
| assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1])) | ||
|
|
||
|
|
||
|
|
@@ -52,6 +55,7 @@ def sub_scalar_rhs_typed(x: IList[float, Any], y: float): | |
| return x - y | ||
|
|
||
|
|
||
| @pytest.mark.xfail() | ||
|
||
| def test_sub_scalar_typed(): | ||
| res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1) | ||
| assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1])) | ||
|
|
@@ -62,11 +66,30 @@ def mult_scalar_lhs_typed(x: float, y: IList[float, Any]): | |
| return x * y | ||
|
|
||
|
|
||
| @basic.union([vmath])(typeinfer=True) | ||
| def mult_kernel(x, y): | ||
| return x * y | ||
|
|
||
|
|
||
| @basic.union([vmath])(typeinfer=True, aggressive=True) | ||
| def mult_scalar_lhs(): | ||
| return mult_kernel(x=3.0, y=[3.0, 4.0, 5.0]) | ||
|
|
||
|
|
||
| @pytest.mark.xfail() | ||
|
||
| def test_mult_scalar_typed(): | ||
| res = mult_scalar_lhs_typed(3, IList([0, 1, 2])) | ||
| assert np.allclose(np.asarray(res), np.asarray([0, 3, 6])) | ||
|
|
||
|
|
||
| @pytest.mark.xfail() | ||
|
||
| def test_mult_scalar_lhs(): | ||
| res = mult_scalar_lhs() | ||
| assert isinstance(res, IList) | ||
| assert res.type.vars[0].typ is float | ||
| assert np.allclose(np.asarray(res), np.array([9, 12, 15])) | ||
|
|
||
|
|
||
| @basic.union([vmath])(typeinfer=True) | ||
| def div_scalar_lhs_typed(x: float, y: IList[float, Any]): | ||
| return x / y | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
@pytest.mark.xfail()decorators should include areasonparameter to explain why these tests are expected to fail. Other tests in the codebase consistently use@mark.xfail(reason="...")to document the known issue.For example, based on the PR description, consider:
@pytest.mark.xfail(reason="type inference results in Bottom type for these cases")