Skip to content

Commit 0137c6f

Browse files
committed
Tweaked Tile.show to be injected only when rf_ipython is imported and
running in ipython environment.
1 parent 9aee29e commit 0137c6f

File tree

2 files changed

+63
-52
lines changed

2 files changed

+63
-52
lines changed

pyrasterframes/src/main/python/pyrasterframes/rf_ipython.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,49 @@
1919
#
2020

2121
import pyrasterframes.rf_types
22+
import numpy as np
23+
24+
def plot_tile(tile, lower_percentile=1, upper_percentile=99, axis=None, **imshow_args):
25+
"""
26+
Display an image of the tile
27+
28+
Parameters
29+
----------
30+
lower_percentile: between 0 and 100 inclusive.
31+
Specifies to clip values below this percentile
32+
upper_percentile: between 0 and 100 inclusive.
33+
Specifies to clip values above this percentile
34+
axis : matplotlib axis object to plot onto. Creates new axis if None
35+
imshow_args : parameters to pass into matplotlib.pyplot.imshow
36+
see https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.imshow.html
37+
Returns
38+
-------
39+
created or modified axis object
40+
"""
41+
42+
if axis is None:
43+
import matplotlib.pyplot as plt
44+
axis = plt.gca()
45+
46+
arr = tile.cells
47+
48+
def normalize_cells(cells, lower_percentile=lower_percentile, upper_percentile=upper_percentile):
49+
assert upper_percentile > lower_percentile, 'invalid upper and lower percentiles'
50+
lower = np.percentile(cells, lower_percentile)
51+
upper = np.percentile(cells, upper_percentile)
52+
cells_clipped = np.clip(cells, lower, upper)
53+
return (cells_clipped - lower) / (upper - lower)
2254

55+
axis.set_aspect('equal')
56+
axis.xaxis.set_ticks([])
57+
axis.yaxis.set_ticks([])
58+
59+
axis.imshow(normalize_cells(arr), **imshow_args)
60+
61+
return axis
2362

24-
def tile_to_png(tile, fig_size=None):
63+
64+
def tile_to_png(tile, lower_percentile=1, upper_percentile=99, title=None, fig_size=None):
2565
""" Provide image of Tile."""
2666
if tile.cells is None:
2767
return None
@@ -31,23 +71,24 @@ def tile_to_png(tile, fig_size=None):
3171
from matplotlib.figure import Figure
3272

3373
# Set up matplotlib objects
34-
nominal_size = 2 # approx full size for a 256x256 tile
74+
nominal_size = 4 # approx full size for a 256x256 tile
3575
if fig_size is None:
3676
fig_size = (nominal_size, nominal_size)
3777

3878
fig = Figure(figsize=fig_size)
3979
canvas = FigureCanvas(fig)
4080
axis = fig.add_subplot(1, 1, 1)
4181

42-
data = tile.cells
43-
44-
axis.imshow(data)
82+
plot_tile(tile, lower_percentile, upper_percentile, axis=axis)
4583
axis.set_aspect('equal')
4684
axis.xaxis.set_ticks([])
4785
axis.yaxis.set_ticks([])
4886

49-
axis.set_title('{}, {}'.format(tile.dimensions(), tile.cell_type.__repr__()),
50-
fontsize=fig_size[0]*4) # compact metadata as title
87+
if title is None:
88+
axis.set_title('{}, {}'.format(tile.dimensions(), tile.cell_type.__repr__()),
89+
fontsize=fig_size[0]*4) # compact metadata as title
90+
else:
91+
axis.set_title(title, fontsize=fig_size[0]*4) # compact metadata as title
5192

5293
with io.BytesIO() as output:
5394
canvas.print_png(output)
@@ -58,7 +99,7 @@ def tile_to_html(tile, fig_size=None):
5899
""" Provide HTML string representation of Tile image."""
59100
import base64
60101
b64_img_html = '<img src="data:image/png;base64,{}" />'
61-
png_bits = tile_to_png(tile, fig_size)
102+
png_bits = tile_to_png(tile, fig_size=fig_size)
62103
b64_png = base64.b64encode(png_bits).decode('utf-8').replace('\n', '')
63104
return b64_img_html.format(b64_png)
64105

@@ -102,6 +143,18 @@ def _safe_tile_to_html(t):
102143
pd.set_option('display.max_colwidth', default_max_colwidth)
103144
return return_html
104145

146+
147+
try:
148+
from IPython import get_ipython
149+
# modifications to currently running ipython session, if we are in one; these enable nicer visualization for Pandas
150+
if get_ipython() is not None:
151+
import pandas
152+
html_formatter = get_ipython().display_formatter.formatters['text/html']
153+
html_formatter.for_type(pandas.DataFrame, pandas_df_to_html)
154+
except ImportError:
155+
pass
156+
157+
105158
def spark_df_to_markdown(df, num_rows=5, truncate=True, vertical=False):
106159
from pyrasterframes import RFContext
107160
return RFContext.active().call("_dfToMarkdown", df._jdf, num_rows, truncate)
@@ -122,14 +175,13 @@ def spark_df_to_markdown(df, num_rows=5, truncate=True, vertical=False):
122175
markdown_formatter = ip.display_formatter.formatters['text/markdown']
123176
html_formatter.for_type(pyspark.sql.DataFrame, spark_df_to_markdown)
124177

125-
Tile.show = lambda t: display_png(t._repr_png_(), raw=True)
178+
Tile.show = lambda tile, lower_percentile=1, upper_percentile=99, axis=None, **imshow_args: \
179+
plot_tile(tile, lower_percentile, upper_percentile, axis, **imshow_args)
126180

127181
# See if we're in documentation mode and register a custom show implementation.
128182
if 'InProcessInteractiveShell' in ip.__class__.__name__:
129183
pyspark.sql.DataFrame._repr_markdown_ = spark_df_to_markdown
130184
pyspark.sql.DataFrame.show = lambda df, num_rows=5, truncate=True: display_markdown(spark_df_to_markdown(df, num_rows, truncate), raw=True)
131185

132186
except ImportError as e:
133-
print(e)
134-
raise e
135187
pass

pyrasterframes/src/main/python/pyrasterframes/rf_types.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -363,47 +363,6 @@ def _repr_png_(self):
363363
from pyrasterframes.rf_ipython import tile_to_png
364364
return tile_to_png(self)
365365

366-
def show(self, lower_percentile=1, upper_percentile=99, axis=None, **imshow_args):
367-
"""
368-
Display an image of the tile
369-
370-
Parameters
371-
----------
372-
lower_percentile: between 0 and 100 inclusive.
373-
Specifies to clip values below this percentile
374-
upper_percentile: between 0 and 100 inclusive.
375-
Specifies to clip values above this percentile
376-
axis : matplotlib axis object to plot onto. Creates new axis if None
377-
imshow_args : parameters to pass into matplotlib.pyplot.imshow
378-
see https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.imshow.html
379-
Returns
380-
-------
381-
created or modified axis object
382-
"""
383-
384-
if axis is None:
385-
import matplotlib.pyplot as plt
386-
axis = plt.gca()
387-
388-
arr = self.cells
389-
390-
def normalize_cells(cells, lower_percentile=lower_percentile, upper_percentile=upper_percentile):
391-
assert upper_percentile > lower_percentile, 'invalid upper and lower percentiles'
392-
lower = np.percentile(cells, lower_percentile)
393-
upper = np.percentile(cells, upper_percentile)
394-
cells_clipped = np.clip(cells, lower, upper)
395-
return (cells_clipped - lower) / (upper - lower)
396-
397-
axis.set_aspect('equal')
398-
axis.xaxis.set_ticks([])
399-
axis.yaxis.set_ticks([])
400-
401-
axis.imshow(normalize_cells(arr), **imshow_args)
402-
403-
return axis
404-
405-
406-
407366

408367
class TileUDT(UserDefinedType):
409368
@classmethod

0 commit comments

Comments
 (0)