-
Notifications
You must be signed in to change notification settings - Fork 70
Maximize compatability with Datatypes by returning NotImplemented if __add__, __mul__ ... fail #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 7 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
03007f3
Return not implemented if to_linexpr() fails
FBumann 6c1cfdd
Add try except blocks to arithmetrics of Expression
FBumann bc313c3
Add try except blocks to __mul__ of Expression
FBumann 216d044
Add try except blocks arithmetrics of Variable
FBumann f7241db
add tests
FBumann 26d6180
Added comment to release_notes.rst
FBumann 615fb4d
Adjust random Generator
FBumann 084d112
add typehints and ignore types to test to satisfy mypy
FBumann e60a676
Removing uncommenting in release_notes.rst
FBumann 0993c22
Add extra test case
FBumann b113b9f
Remove unnessesary try catch
FBumann f64db9c
Wrapping everything into the try except in arithmetric operations
FBumann 8b24692
Remove try except from __pow__
FBumann da82360
Add more tests
FBumann 921c850
Add try except block to Expression.__div__
FBumann 35c9a03
Add more test cases
FBumann ed1b5e5
Remove some mypy errors
FBumann File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| import pytest | ||
| import xarray as xr | ||
|
|
||
| from linopy import Model | ||
| from linopy.testing import assert_linequal | ||
|
|
||
|
|
||
| class SomeOtherDatatype: | ||
| """ | ||
| A class that is not a subclass of xarray.DataArray, but stores data in a compatible way. | ||
| It defines all necessary arithmetrics AND __array_ufunc__ to ensure that operations are | ||
| performed on the active_data. | ||
| """ | ||
|
|
||
| def __init__(self, data: xr.DataArray) -> None: | ||
| self.data1 = data | ||
| self.data2 = data.copy() | ||
| self.active = 1 | ||
|
|
||
| def activate(self, active: int) -> None: | ||
| self.active = active | ||
|
|
||
| @property | ||
| def active_data(self) -> xr.DataArray: | ||
| return self.data1 if self.active == 1 else self.data2 | ||
|
|
||
| def __add__(self, other): | ||
| return self.active_data + other | ||
|
|
||
| def __sub__(self, other): | ||
| return self.active_data - other | ||
|
|
||
| def __mul__(self, other): | ||
| return self.active_data * other | ||
|
|
||
| def __truediv__(self, other): | ||
| return self.active_data / other | ||
|
|
||
| def __radd__(self, other): | ||
| return other + self.active_data | ||
|
|
||
| def __rsub__(self, other): | ||
| return other - self.active_data | ||
|
|
||
| def __rmul__(self, other): | ||
| return other * self.active_data | ||
|
|
||
| def __rtruediv__(self, other): | ||
| return other / self.active_data | ||
|
|
||
| def __neg__(self): | ||
| return -self.active_data | ||
|
|
||
| def __pos__(self): | ||
| return +self.active_data | ||
|
|
||
| def __abs__(self): | ||
| return abs(self.active_data) | ||
|
|
||
| def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): | ||
| # Ensure we always use the active_data when interacting with numpy/xarray operations | ||
| new_inputs = [ | ||
| inp.active_data if isinstance(inp, SomeOtherDatatype) else inp | ||
| for inp in inputs | ||
| ] | ||
| return getattr(ufunc, method)(*new_inputs, **kwargs) | ||
|
|
||
|
|
||
| @pytest.fixture( | ||
| params=[ | ||
| (pd.RangeIndex(10, name="first"),), | ||
| ( | ||
| pd.Index(range(5), name="first"), | ||
| pd.Index(range(3), name="second"), | ||
| pd.Index(range(2), name="third"), | ||
| ), | ||
| ], | ||
| ids=["single_dim", "multi_dim"], | ||
| ) | ||
| def m(request) -> Model: | ||
| m = Model() | ||
| m.add_variables(coords=request.param, name="x") | ||
| m.add_variables(0, 10, name="z") | ||
| m.add_constraints(m.variables["x"] >= 0, name="c") | ||
| return m | ||
|
|
||
|
|
||
| def test_arithmetric_operations_variable(m: Model) -> None: | ||
| x = m.variables["x"] | ||
| rng = np.random.default_rng() | ||
| data = xr.DataArray(rng.random(x.shape), coords=x.coords) | ||
| other_datatype = SomeOtherDatatype(data.copy()) | ||
| assert_linequal(x + data, x + other_datatype) | ||
| assert_linequal(x - data, x - other_datatype) | ||
| assert_linequal(x * data, x * other_datatype) | ||
| assert_linequal(x / data, x / other_datatype) | ||
|
|
||
|
|
||
| def test_arithmetric_operations_con(m: Model) -> None: | ||
| c = m.constraints["c"] | ||
| x = m.variables["x"] | ||
| rng = np.random.default_rng() | ||
| data = xr.DataArray(rng.random(x.shape), coords=x.coords) | ||
| other_datatype = SomeOtherDatatype(data.copy()) | ||
| assert_linequal(c.lhs + data, c.lhs + other_datatype) | ||
| assert_linequal(c.lhs - data, c.lhs - other_datatype) | ||
| assert_linequal(c.lhs * data, c.lhs * other_datatype) | ||
| assert_linequal(c.lhs / data, c.lhs / other_datatype) | ||
| assert_linequal(c.rhs + data, c.rhs + other_datatype) | ||
| assert_linequal(c.rhs - data, c.rhs - other_datatype) | ||
| assert_linequal(c.rhs * data, c.rhs * other_datatype) | ||
| assert_linequal(c.rhs / data, c.rhs / other_datatype) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
..means this is commented out, please remove it for your note and the "upcoming version" header