Skip to content

Commit 6215ce1

Browse files
committed
Merge pull request #215 from guziy/shiftdata_issue2013
change shiftdata so 1 and 2 points work, add tests.. (#213)
2 parents 6346916 + fcb7a4c commit 6215ce1

File tree

2 files changed

+75
-13
lines changed

2 files changed

+75
-13
lines changed

lib/mpl_toolkits/basemap/__init__.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4777,10 +4777,15 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
47774777
lonsin1 = lonsin[0,:]
47784778
lonsin1 = np.where(lonsin1 > lon_0+180, lonsin1-360 ,lonsin1)
47794779
lonsin1 = np.where(lonsin1 < lon_0-180, lonsin1+360 ,lonsin1)
4780-
londiff = np.abs(lonsin1[0:-1]-lonsin1[1:])
4781-
londiff_sort = np.sort(londiff)
4782-
thresh = 360.-londiff_sort[-2]
4783-
itemindex = nlons-np.where(londiff>=thresh)[0]
4780+
if nlons > 1:
4781+
londiff = np.abs(lonsin1[0:-1]-lonsin1[1:])
4782+
londiff_sort = np.sort(londiff)
4783+
thresh = 360.-londiff_sort[-2] if nlons > 2 else 360.-londiff_sort[-1]
4784+
itemindex = nlons-np.where(londiff>=thresh)[0]
4785+
else:
4786+
lonsin[0, :] = lonsin1
4787+
itemindex = 0
4788+
47844789
# if no shift necessary, itemindex will be
47854790
# empty, so don't do anything
47864791
if itemindex:
@@ -4822,10 +4827,15 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
48224827
nlons = len(lonsin)
48234828
lonsin = np.where(lonsin > lon_0+180, lonsin-360 ,lonsin)
48244829
lonsin = np.where(lonsin < lon_0-180, lonsin+360 ,lonsin)
4825-
londiff = np.abs(lonsin[0:-1]-lonsin[1:])
4826-
londiff_sort = np.sort(londiff)
4827-
thresh = 360.-londiff_sort[-2]
4828-
itemindex = len(lonsin)-np.where(londiff>=thresh)[0]
4830+
4831+
if nlons > 1:
4832+
londiff = np.abs(lonsin[0:-1]-lonsin[1:])
4833+
londiff_sort = np.sort(londiff)
4834+
thresh = 360.-londiff_sort[-2] if nlons > 2 else 360.0 - londiff_sort[-1]
4835+
itemindex = len(lonsin)-np.where(londiff>=thresh)[0]
4836+
else:
4837+
itemindex = 0
4838+
48294839
if itemindex:
48304840
# check to see if cyclic (wraparound) point included
48314841
# if so, remove it.

lib/mpl_toolkits/basemap/test.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def make_array(self):
1313
u = np.ones((len(lat), len(lon)))
1414
v = np.zeros((len(lat), len(lon)))
1515
return u,v,lat,lon
16-
16+
1717
def test_cylindrical(self):
1818
# Cylindrical case
1919
B = Basemap()
@@ -22,7 +22,7 @@ def test_cylindrical(self):
2222
# Check that the vectors are identical.
2323
assert_almost_equal(ru, u)
2424
assert_almost_equal(rv, v)
25-
25+
2626
def test_nan(self):
2727
B = Basemap()
2828
u,v,lat,lon=self.make_array()
@@ -32,12 +32,12 @@ def test_nan(self):
3232
assert not np.isnan(ru).any()
3333
assert_almost_equal(u, ru)
3434
assert_almost_equal(v, rv)
35-
35+
3636
def test_npstere(self):
3737
# NP Stereographic case
3838
B=Basemap(projection='npstere', boundinglat=50., lon_0=0.)
3939
u,v,lat,lon=self.make_array()
40-
v = np.ones((len(lat), len(lon)))
40+
v = np.ones((len(lat), len(lon)))
4141
ru, rv = B.rotate_vector(u,v, lon, lat)
4242
assert_almost_equal(ru[2, :],[1,-1,-1,1], 6)
4343
assert_almost_equal(rv[2, :],[1,1,-1,-1], 6)
@@ -96,6 +96,59 @@ def test_no_cyc2(self):
9696
assert (grid==gridout).all()
9797

9898

99+
class TestShiftdata(TestCase):
100+
101+
def _get_2d_lons(self, lons1d):
102+
"""
103+
Generate a 2d grid
104+
"""
105+
lats = [10, ] * len(lons1d)
106+
return np.meshgrid(lons1d, lats)[0]
107+
108+
def test_2_points_should_work(self):
109+
"""
110+
Shiftdata should work with 2 points
111+
"""
112+
bm = Basemap(llcrnrlon=0, llcrnrlat=-80, urcrnrlon=360, urcrnrlat=80, projection='mill')
113+
114+
lons_expected = [10, 15, 20]
115+
lonsout = bm.shiftdata(lons_expected[:])
116+
assert_almost_equal(lons_expected, lonsout)
117+
118+
lonsout_expected = bm.shiftdata([10, 361, 362])
119+
lonsout = bm.shiftdata([10, 361])
120+
assert_almost_equal(lonsout_expected[:len(lonsout)], lonsout)
121+
122+
def test_1_point_should_work(self):
123+
bm = Basemap(llcrnrlon=0, llcrnrlat=-80, urcrnrlon=360, urcrnrlat=80, projection='mill')
124+
125+
# should not fail
126+
lonsout = bm.shiftdata([361])
127+
assert_almost_equal(lonsout, [1.0,])
128+
129+
lonsout = bm.shiftdata([10])
130+
assert_almost_equal(lonsout, [10.0,])
131+
132+
lonsin = np.array([361.0])
133+
lonsin.shape = (1, 1)
134+
lonsout = bm.shiftdata(lonsin[:])
135+
assert_almost_equal(lonsout.squeeze(), [1.0,])
136+
137+
def test_less_than_n_by_3_points_should_work(self):
138+
bm = Basemap(llcrnrlon=0, llcrnrlat=-80, urcrnrlon=360, urcrnrlat=80, projection='mill')
139+
lons_expected = self._get_2d_lons([10, 15, 20])
140+
141+
# nothing should change
142+
lonsout = bm.shiftdata(lons_expected)
143+
assert_almost_equal(lons_expected, lonsout)
144+
145+
# shift n x 3 and n x 2 grids and compare results over overlapping region
146+
lonsin = self._get_2d_lons([10, 361, 362])
147+
lonsout_expected = bm.shiftdata(lonsin[:])[:, :2]
148+
lonsout = bm.shiftdata(lonsin[:, :2])
149+
assert_almost_equal(lonsout_expected, lonsout)
150+
151+
99152
class TestProjectCoords(TestCase):
100153
def get_data(self):
101154
lons, lats = np.arange(-180, 180, 20), np.arange(-90, 90, 10)
@@ -130,7 +183,6 @@ def test_results_should_be_same_for_c_and_f_order_arrays(self):
130183

131184

132185

133-
134186
def test():
135187
"""
136188
Run some tests.

0 commit comments

Comments
 (0)