Skip to content

Commit 0506fb1

Browse files
authored
Merge pull request #294 from s22s/feature/ipython-display-tweaks
Added Markdown and HTML rendering of Spark DataFrames
2 parents 1ea29f2 + a011186 commit 0506fb1

File tree

24 files changed

+894
-772
lines changed

24 files changed

+894
-772
lines changed

core/src/main/resources/reference.conf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ rasterframes {
33
prefer-gdal = true
44
showable-tiles = true
55
showable-max-cells = 20
6+
max-truncate-row-element-length = 40
67
raster-source-cache-timeout = 120 seconds
78
}
89

core/src/main/scala/org/locationtech/rasterframes/util/package.scala

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3838
import org.apache.spark.sql.catalyst.rules.Rule
3939
import org.apache.spark.sql.functions._
4040
import org.apache.spark.sql.rf._
41-
import org.apache.spark.sql.types.StringType
41+
import org.apache.spark.sql.types.{StringType, StructField}
4242
import org.apache.spark.sql._
4343
import org.slf4j.LoggerFactory
4444
import spire.syntax.cfor._
@@ -184,24 +184,58 @@ package object util {
184184
}
185185
}
186186

187+
private val truncateWidth = rfConfig.getInt("max-truncate-row-element-length")
188+
187189
implicit class DFWithPrettyPrint(val df: Dataset[_]) extends AnyVal {
190+
191+
def stringifyRowElements(cols: Seq[StructField], truncate: Boolean) = {
192+
cols
193+
.map(c => s"`${c.name}`")
194+
.map(c => df.col(c).cast(StringType))
195+
.map(c => if (truncate) {
196+
when(length(c) > lit(truncateWidth), concat(substring(c, 1, truncateWidth), lit("...")))
197+
.otherwise(c)
198+
} else c)
199+
}
200+
188201
def toMarkdown(numRows: Int = 5, truncate: Boolean = false): String = {
189202
import df.sqlContext.implicits._
190-
val cols = df.columns
191-
val header = cols.mkString("| ", " | ", " |") + "\n" + ("|---" * cols.length) + "|\n"
192-
val stringifiers = cols
193-
.map(c => s"`$c`")
194-
.map(c => df.col(c).cast(StringType))
195-
.map(c => if (truncate) substring(c, 1, 40) else c)
203+
val cols = df.schema.fields
204+
val header = cols.map(_.name).mkString("| ", " | ", " |") + "\n" + ("|---" * cols.length) + "|\n"
205+
val stringifiers = stringifyRowElements(cols, truncate)
196206
val cat = concat_ws(" | ", stringifiers: _*)
197-
val body = df
198-
.select(cat).limit(numRows)
207+
val rows = df
208+
.select(cat)
209+
.limit(numRows)
199210
.as[String]
200211
.collect()
201212
.map(_.replaceAll("\\[", "\\\\["))
202213
.map(_.replace('\n', '↩'))
214+
215+
val body = rows
203216
.mkString("| ", " |\n| ", " |")
204-
header + body
217+
218+
val caption = if (rows.length >= numRows) s"\n_Showing only top $numRows rows_.\n\n" else ""
219+
caption + header + body
220+
}
221+
222+
def toHTML(numRows: Int = 5, truncate: Boolean = false): String = {
223+
import df.sqlContext.implicits._
224+
val cols = df.schema.fields
225+
val header = "<thead>\n" + cols.map(_.name).mkString("<tr><th>", "</th><th>", "</th></tr>\n") + "</thead>\n"
226+
val stringifiers = stringifyRowElements(cols, truncate)
227+
val cat = concat_ws("</td><td>", stringifiers: _*)
228+
val rows = df
229+
.select(cat).limit(numRows)
230+
.as[String]
231+
.collect()
232+
233+
val body = rows
234+
.mkString("<tr><td>", "</td></tr>\n<tr><td>", "</td></tr>\n")
235+
236+
val caption = if (rows.length >= numRows) s"<caption>Showing only top $numRows rows</caption>\n" else ""
237+
238+
"<table>\n" + caption + header + "<tbody>\n" + body + "</tbody>\n" + "</table>"
205239
}
206240
}
207241

core/src/test/scala/org/locationtech/rasterframes/ExtensionMethodSpec.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ import geotrellis.raster.{ByteCellType, GridBounds, TileLayout}
2626
import geotrellis.spark.tiling.{CRSWorldExtent, LayoutDefinition}
2727
import geotrellis.spark.{KeyBounds, SpatialKey, TileLayerMetadata}
2828
import org.apache.spark.sql.Encoders
29-
import org.locationtech.rasterframes.util.SubdivideSupport
29+
import org.locationtech.rasterframes.util._
30+
31+
import scala.xml.parsing.XhtmlParser
3032

3133
/**
3234
* Tests miscellaneous extension methods.
@@ -111,8 +113,18 @@ class ExtensionMethodSpec extends TestEnvironment with TestData with SubdivideSu
111113
}
112114

113115
it("should render Markdown") {
  // Default render: expect at least 3 pipe-delimited columns over 5 lines,
  // plus enough newlines for header, separator, and body rows.
  val rendered = rf.toMarkdown()
  rendered.count(_ == '|') shouldBe >=(3 * 5)
  rendered.count(_ == '\n') should be >= 6

  // Truncated render must mark shortened cells with an ellipsis.
  val truncated = rf.toMarkdown(truncate = true)
  truncated should include ("...")
}
123+
124+
it("should render HTML") {
  // Rendering must produce well-formed XHTML, verified by round-tripping
  // the output through the XHTML parser.
  noException shouldBe thrownBy {
    val markup = rf.toHTML()
    XhtmlParser(scala.io.Source.fromString(markup))
  }
}
117129
}
118130
}

docs/src/main/paradox/_template/page.st

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
.md-clear { clear: both; }
3434
table { font-size: 80%; }
3535
code { font-size: 0.75em !important; }
36+
table a {
37+
word-break: break-all;
38+
}
3639
</style>
3740
</head>
3841

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### 0.8.1
66

7+
* Added `toMarkdown()` and `toHTML()` extension methods for `DataFrame`, and registered them with the IPython formatter system when `rf_ipython` is imported.
78
* Fixed: Removed false return type guarantee in cases where an `Expression` accepts either `Tile` or `ProjectedRasterTile` [(#295)](https://github.com/locationtech/rasterframes/issues/295)
89

910
### 0.8.0

pyrasterframes/src/main/python/docs/aggregation.pymd

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ print(tiles[1]['tile'].cells)
3333

3434
We use the @ref:[`rf_tile_mean`](reference.md#rf-tile-mean) function to compute the _tile_ aggregate mean of cells in each row of column `tile`. The mean of each _tile_ is computed separately, so the first mean is 1.0 and the second mean is 3.0. Notice that the number of rows in the DataFrame is the same before and after the aggregation.
3535

36-
```python, tile_mean, results='raw'
37-
rf.select(F.col('id'), rf_tile_mean(F.col('tile'))).show()
36+
```python, tile_mean
37+
means = rf.select(F.col('id'), rf_tile_mean(F.col('tile')))
38+
means
3839
```
3940

4041
We use the @ref:[`rf_agg_mean`](reference.md#rf-agg-mean) function to compute the DataFrame aggregate, which averages 25 values of 1.0 and 25 values of 3.0, across the fifty cells in two rows. Note that only a single row is returned since the average is computed over the full DataFrame.
4142

42-
```python, agg_mean, results='raw'
43-
rf.agg(rf_agg_mean(F.col('tile'))).show()
43+
```python, agg_mean
44+
mean = rf.agg(rf_agg_mean(F.col('tile')))
45+
mean
4446
```
4547

4648
We use the @ref:[`rf_agg_local_mean`](reference.md#rf-agg-local-mean) function to compute the element-wise local aggregate mean across the two rows. For this aggregation, we are computing the mean of one value of 1.0 and one value of 3.0 to arrive at the element-wise mean, but doing so twenty-five times, one for each position in the _tile_.
@@ -57,11 +59,10 @@ print(t.cells)
5759

5860
We can also count the total number of data and NoData cells over all the _tiles_ in a DataFrame using @ref:[`rf_agg_data_cells`](reference.md#rf-agg-data-cells) and @ref:[`rf_agg_no_data_cells`](reference.md#rf-agg-no-data-cells). There are ~3.8 million data cells and ~1.9 million NoData cells in this DataFrame. See the section on @ref:["NoData" handling](nodata-handling.md) for additional discussion on handling missing data.
5961

60-
```python, cell_counts, results='raw'
62+
```python, cell_counts
6163
rf = spark.read.raster('https://s22s-test-geotiffs.s3.amazonaws.com/MCD43A4.006/11/05/2018233/MCD43A4.A2018233.h11v05.006.2018242035530_B02.TIF')
6264
stats = rf.agg(rf_agg_data_cells('proj_raster'), rf_agg_no_data_cells('proj_raster'))
63-
64-
stats.show()
65+
stats
6566
```
6667

6768
## Statistical Summaries
@@ -77,16 +78,16 @@ stats = rf.select(rf_tile_stats('proj_raster').alias('stats'))
7778
stats.printSchema()
7879
```
7980

80-
```python, show_stats, results='raw'
81-
stats.select('stats.min', 'stats.max', 'stats.mean', 'stats.variance').show(10, truncate=False)
81+
```python, show_stats
82+
stats.select('stats.min', 'stats.max', 'stats.mean', 'stats.variance')
8283
```
8384

8485
The @ref:[`rf_agg_stats`](reference.md#rf-agg-stats) function aggregates over all of the _tiles_ in a DataFrame and returns a statistical summary of all cell values as shown below.
8586

86-
```python, agg_stats, results='raw'
87-
rf.agg(rf_agg_stats('proj_raster').alias('stats')) \
88-
.select('stats.min', 'stats.max', 'stats.mean', 'stats.variance') \
89-
.show()
87+
```python, agg_stats
88+
stats = rf.agg(rf_agg_stats('proj_raster').alias('stats')) \
89+
.select('stats.min', 'stats.max', 'stats.mean', 'stats.variance')
90+
stats
9091
```
9192

9293
The @ref:[`rf_agg_local_stats`](reference.md#rf-agg-local-stats) function computes the element-wise local aggregate statistical summary as shown below. The DataFrame used in the previous two code blocks has unequal _tile_ dimensions, so a different DataFrame is used in this code block to avoid a runtime error.

pyrasterframes/src/main/python/docs/getting-started.pymd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,17 @@ spark = pyrasterframes.get_spark_session()
3434

3535
Then, you can read a raster and work with it in a Spark DataFrame.
3636

37-
```python, local_add, results='raw'
37+
```python, local_add
3838
from pyrasterframes.rasterfunctions import *
3939
from pyspark.sql.functions import lit
4040

4141
# Read a MODIS surface reflectance granule
4242
df = spark.read.raster('https://modis-pds.s3.amazonaws.com/MCD43A4.006/11/08/2019059/MCD43A4.A2019059.h11v08.006.2019072203257_B02.TIF')
4343

4444
# Add 3 element-wise, show some rows of the DataFrame
45-
df.withColumn('added', rf_local_add(df.proj_raster, lit(3))) \
46-
.select(rf_crs('added'), rf_extent('added'), rf_tile('added')) \
47-
.show(3)
45+
sample = df.withColumn('added', rf_local_add(df.proj_raster, lit(3))) \
46+
.select(rf_crs('added'), rf_extent('added'), rf_tile('added'))
47+
sample
4848
```
4949

5050
This example is extended in the [getting started Jupyter notebook](https://nbviewer.jupyter.org/github/locationtech/rasterframes/blob/develop/rf-notebook/src/main/notebooks/Getting%20Started.ipynb).

pyrasterframes/src/main/python/docs/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ The source code can be found on GitHub at [locationtech/rasterframes](https://gi
1010

1111
<img src="RasterFramePipeline.png" width="600px"/>
1212

13+
RasterFrames is released under the [Apache 2.0 License](https://github.com/locationtech/rasterframes/blob/develop/LICENSE).
14+
1315
<hr/>
1416

1517
@@@ div { .md-left}

pyrasterframes/src/main/python/docs/languages.pymd

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ red_nir_tiles_monthly_2017 = spark.read.raster(
5050

5151
### Step 4: Compute aggregates
5252

53-
```python, step_4_python, results='raw'
53+
```python, step_4_python
5454
result = red_nir_tiles_monthly_2017 \
5555
.where(st_intersects(
5656
st_reproject(rf_geometry(col('red')), rf_crs(col('red')).crsProj4, rf_mk_crs('EPSG:4326')),
@@ -60,7 +60,7 @@ result = red_nir_tiles_monthly_2017 \
6060
.agg(rf_agg_stats(rf_normalized_difference(col('nir'), col('red'))).alias('ndvi_stats')) \
6161
.orderBy(col('month')) \
6262
.select('month', 'ndvi_stats.*')
63-
result.show()
63+
result
6464
```
6565

6666
## SQL
@@ -80,14 +80,14 @@ sql("CREATE OR REPLACE TEMPORARY VIEW modis USING `aws-pds-modis-catalog`")
8080

8181
### Step 2: Down-select data by month
8282

83-
```python, step_2_sql, results='raw'
83+
```python, step_2_sql
8484
sql("""
8585
CREATE OR REPLACE TEMPORARY VIEW red_nir_monthly_2017 AS
8686
SELECT granule_id, month(acquisition_date) as month, B01 as red, B02 as nir
8787
FROM modis
8888
WHERE year(acquisition_date) = 2017 AND day(acquisition_date) = 15 AND granule_id = 'h21v09'
8989
""")
90-
sql('DESCRIBE red_nir_monthly_2017').show()
90+
sql('DESCRIBE red_nir_monthly_2017')
9191
```
9292

9393
### Step 3: Read tiles
@@ -106,16 +106,17 @@ OPTIONS (
106106

107107
### Step 4: Compute aggregates
108108

109-
```python, step_4_sql, results='raw'
110-
sql("""
109+
```python, step_4_sql
110+
grouped = sql("""
111111
SELECT month, ndvi_stats.* FROM (
112112
SELECT month, rf_agg_stats(rf_normalized_difference(nir, red)) as ndvi_stats
113113
FROM red_nir_tiles_monthly_2017
114114
WHERE st_intersects(st_reproject(rf_geometry(red), rf_crs(red), 'EPSG:4326'), st_makePoint(34.870605, -4.729727))
115115
GROUP BY month
116116
ORDER BY month
117117
)
118-
""").show()
118+
""")
119+
grouped
119120
```
120121

121122
## Scala

0 commit comments

Comments
 (0)