diff --git a/CHANGELOG.md b/CHANGELOG.md index 15dc78a0..2b87ad58 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Version 2.4 (in development) +### New features + +* [Pull Request 325](https://github.com/MassimoCimmino/pygfunction/pull/325) - Borefields and boreholes can now be concatenated using the `+` operator, e.g. using `new_field = field_1 + field_2`. + ### Other changes * [Issue 319](https://github.com/MassimoCimmino/pygfunction/issues/319) - Created `solvers` module. `Solver` classes are moved out of the `gfunction` module and into the new module. diff --git a/pygfunction/borefield.py b/pygfunction/borefield.py index 0a6b7dbc..3a92b888 100644 --- a/pygfunction/borefield.py +++ b/pygfunction/borefield.py @@ -107,6 +107,42 @@ def __ne__( check = not self == other_field return check + def __add__(self, + other_field: Union[Borehole, List[Borehole], Self]) -> Self: + """Add two borefields together""" + if not isinstance(other_field, (Borehole, list, self.__class__)): + raise TypeError( + f'Expected Borefield, list or Borehole input;' + f' got {other_field}' + ) + # List of boreholes + field = self.to_boreholes() + # Convert other_field to a list if it is a Borehole + if isinstance(other_field, Borehole): + other_field = [other_field] + # Convert borefield to a list if it is a Borefield + if isinstance(other_field, self.__class__): + other_field = other_field.to_boreholes() + return Borefield.from_boreholes(field + other_field) + + def __radd__(self, + other_field: Union[Borehole, List[Borehole], Self]) -> Self: + """Add two borefields together""" + if not isinstance(other_field, (Borehole, list, self.__class__)): + raise TypeError( + f'Expected Borefield, list or Borehole input;' + f' got {other_field}' + ) + # List of boreholes + field = self.to_boreholes() + # Convert other_field to a list if it is a Borehole + if isinstance(other_field, Borehole): + other_field = [other_field] + # Convert borefield to a list if it is a Borefield + if isinstance(other_field, self.__class__): + other_field = other_field.to_boreholes() + return Borefield.from_boreholes(other_field + field) + def evaluate_g_function( self, alpha: float, diff --git a/pygfunction/boreholes.py b/pygfunction/boreholes.py index 58c10013..330ed4b7 100644 --- a/pygfunction/boreholes.py +++ b/pygfunction/boreholes.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- +from typing import Union import warnings import numpy as np from scipy.spatial.distance import pdist +from typing_extensions import Self # for compatibility with Python <= 3.10 from .utilities import _initialize_figure, _format_axes, _format_axes_3d @@ -54,6 +56,56 @@ def __repr__(self): f' orientation={self.orientation})') return s + def __add__(self, other: Union[Self, list]): + """ + Adds two boreholes together to form a borefield + """ + if not isinstance(other, (self.__class__, list)): + # Check if other is a borefield and try the operation using + # other.__radd__ + try: + field = other.__radd__(self) + except: + # Invalid input + raise TypeError( + f'Expected Borefield, list or Borehole input;' + f' got {other}' + ) + elif isinstance(other, list): + # Create a borefield from the borehole and a list + from .borefield import Borefield + field = Borefield.from_boreholes([self] + other) + else: + # Create a borefield from the two boreholes + from .borefield import Borefield + field = Borefield.from_boreholes([self, other]) + return field + + def __radd__(self, other: Union[Self, list]): + """ + Adds two boreholes together to form a borefield + """ + if not isinstance(other, (self.__class__, list)): + # Check if other is a borefield and try the operation using + # other.__radd__ + try: + field = other.__add__(self) + except: + # Invalid input + raise TypeError( + f'Expected Borefield, list or Borehole input;' + f' got {other}' + ) + elif isinstance(other, list): + # Create a borefield from the borehole and a list + from .borefield import Borefield + field = Borefield.from_boreholes(other + [self]) + else: + # Create a borefield from the two boreholes + from .borefield import Borefield + field = Borefield.from_boreholes([other, self]) + return field + def distance(self, target): """ Evaluate the distance between the current borehole and a target diff --git a/tests/borefield_test.py b/tests/borefield_test.py index d95e5c86..440db758 100644 --- a/tests/borefield_test.py +++ b/tests/borefield_test.py @@ -32,6 +32,35 @@ def test_borefield_init(field, request): H, D, r_b, x, y, tilt=tilt, orientation=orientation) assert borefield == borefield_from_boreholes + +# Test Borefield.__add__ and Borefield.__radd__ +@pytest.mark.parametrize("field, other_field, field_list, other_field_list, field_borehole, other_field_borehole", [ + # Using Borefield objects + ('ten_boreholes_rectangular', 'two_boreholes_inclined', False, False, False, False), + # Using Borefield objects + ('single_borehole', 'two_boreholes_inclined', False, False, True, False), + ('ten_boreholes_rectangular', 'single_borehole_short', False, False, False, True), + # Using Borefield as lists + ('ten_boreholes_rectangular', 'two_boreholes_inclined', False, True, False, False), + ('ten_boreholes_rectangular', 'two_boreholes_inclined', True, False, False, False), + ]) +def test_borefield_add(field, other_field, field_list, other_field_list, field_borehole, other_field_borehole, request): + field = request.getfixturevalue(field) + other_field = request.getfixturevalue(other_field) + reference_field = gt.borefield.Borefield.from_boreholes( + field.to_boreholes() + other_field.to_boreholes() + ) + if field_list: + field = field.to_boreholes() + if other_field_list: + other_field = other_field.to_boreholes() + if field_borehole: + field = field[0] + if other_field_borehole: + other_field = other_field[0] + assert field + other_field == reference_field + + # Test borefield comparison using __eq__ @pytest.mark.parametrize("field, other_field, expected", [ # Fields that are equal @@ -53,6 +82,7 @@ def test_borefield_eq(field, other_field, expected, request): other_field = request.getfixturevalue(other_field) assert (borefield == other_field) == expected + # Test borefield comparison using __ne__ @pytest.mark.parametrize("field, other_field, expected", [ # Fields that are equal diff --git a/tests/boreholes_test.py b/tests/boreholes_test.py index 4a43a771..64a5ae94 100644 --- a/tests/boreholes_test.py +++ b/tests/boreholes_test.py @@ -33,6 +33,24 @@ def test_borehole_init(): ]) +# Test Borehole.__add__ +@pytest.mark.parametrize("borehole, other_borehole, borehole_list, other_borehole_list", [ + ('single_borehole', 'single_borehole_short', False, False), + ('single_borehole', 'single_borehole_short', True, False), + ('single_borehole', 'single_borehole_short', False, True), + ]) +def test_borehole_add(borehole, other_borehole, borehole_list, other_borehole_list, request): + borehole = request.getfixturevalue(borehole)[0] + other_borehole = request.getfixturevalue(other_borehole)[0] + field = gt.borefield.Borefield.from_boreholes( + [borehole, other_borehole]) + if borehole_list: + borehole = [borehole] + if other_borehole_list: + other_borehole = [other_borehole] + assert field == borehole + other_borehole + + # Test Borehole.distance @pytest.mark.parametrize("borehole1, borehole2", [ # Same borehole