
Commit 92c70f2

Fix for #72.

Misc refactoring of Python RFContext methods.

Signed-off-by: Simeon H.K. Fitch <[email protected]>
1 parent b945b8e commit 92c70f2


8 files changed (+84, -53 lines)

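The heart of the change is replacing the module-level `_checked_context()` helper with an `RFContext` class that owns the py4j plumbing. As a quick orientation before the per-file diffs, here is a minimal sketch of the new call pattern; the `tileZeros` lookup is taken from rasterfunctions.py below, and it assumes a RasterFrames-enabled SparkSession is already active:

from pyspark.sql.column import Column
from pyrasterframes.context import RFContext

# Old pattern: jfcn = getattr(_checked_context(), 'tileZeros')
# New pattern: resolve the JVM-side PyRFContext method by name through the class accessor.
jfcn = RFContext.active().lookup('tileZeros')
tile_col = Column(jfcn(3, 3, 'float64'))  # wrap the returned Java Column handle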

pyrasterframes/build.sbt

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ ivyPaths in pysparkCmd := ivyPaths.value.withIvyHome(target.value / "ivy")
 pyTest := {
   val _ = assembly.value
   val s = streams.value
+  s.log.info("Running python tests...")
   val wd = pythonSource.value
   Process("python setup.py test", wd) ! s.log
 }

pyrasterframes/python/geomesa_pyspark/spark.py

Lines changed: 2 additions & 3 deletions
@@ -12,12 +12,11 @@
 from pyspark.sql.types import UserDefinedType
 from pyspark.sql import Row
 from pyspark.sql.types import *
-from pyrasterframes.context import _checked_context
+from pyrasterframes.context import RFContext


 __all__ = ['GeometryUDT']

-
 class GeometryUDT(UserDefinedType):
     """User-defined type (UDT).
@@ -41,4 +40,4 @@ def serialize(self, obj):
         return Row(obj.toBytes)

     def deserialize(self, datum):
-        return _checked_context().generateGeometry(datum[0])
+        return RFContext._jvm_mirror().generateGeometry(datum[0])

pyrasterframes/python/pyrasterframes/__init__.py

Lines changed: 1 addition & 4 deletions
@@ -3,18 +3,15 @@
 appended to PySpark classes.
 """

-
 from __future__ import absolute_import
-from pyspark.sql.types import UserDefinedType
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, DataFrame, DataFrameReader
-from pyspark.sql.types import *
 from pyspark.sql.column import _to_java_column

 # Import RasterFrame types and functions
 from .types import *
 from . import rasterfunctions
-
+from .context import RFContext

 __all__ = ['RasterFrame', 'TileExploder']

pyrasterframes/python/pyrasterframes/context.py

Lines changed: 39 additions & 7 deletions
@@ -4,10 +4,42 @@
 
 from pyspark import SparkContext

-def _checked_context():
-    """ Get the active SparkContext and throw an error if it is not enabled for RasterFrames."""
-    sc = SparkContext._active_spark_context
-    if not hasattr(sc, '_rf_context'):
-        raise AttributeError(
-            "RasterFrames have not been enabled for the active session. Call 'SparkSession.withRasterFrames()'.")
-    return sc._rf_context._jrfctx
+__all__ = ['RFContext']
+
+
+class RFContext(object):
+    """
+    Entrypoint to RasterFrames services
+    """
+    def __init__(self, spark_session):
+        self._spark_session = spark_session
+        self._gateway = spark_session.sparkContext._gateway
+        self._jvm = self._gateway.jvm
+        jsess = self._spark_session._jsparkSession
+        self._jrfctx = self._jvm.astraea.spark.rasterframes.py.PyRFContext(jsess)
+
+    def list_to_seq(self, py_list):
+        conv = self.lookup('listToSeq')
+        return conv(py_list)
+
+    def lookup(self, function_name):
+        return getattr(self._jrfctx, function_name)
+
+    @staticmethod
+    def active():
+        """
+        Get the active Python RFContext and throw an error if it is not enabled for RasterFrames.
+        """
+        sc = SparkContext._active_spark_context
+        if not hasattr(sc, '_rf_context'):
+            raise AttributeError(
+                "RasterFrames have not been enabled for the active session. Call 'SparkSession.withRasterFrames()'.")
+        return sc._rf_context
+
+    @staticmethod
+    def _jvm_mirror():
+        """
+        Get the active Scala PyRFContext and throw an error if it is not enabled for RasterFrames.
+        """
+        return RFContext.active()._jrfctx

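For orientation, a minimal sketch of reaching the new class from a driver script, assuming a local session that has been RasterFrames-enabled via the `withRasterFrames()` call named in the error message above (the builder settings here are illustrative, not part of this commit):

from pyspark.sql import SparkSession
import pyrasterframes  # patches SparkSession/DataFrame, per the package docstring

spark = SparkSession.builder.master('local[*]').getOrCreate().withRasterFrames()

from pyrasterframes.context import RFContext

ctx = RFContext.active()        # Python-side wrapper attached to the SparkContext
jfcn = ctx.lookup('cellTypes')  # resolve a method on the JVM PyRFContext by name
print(jfcn())                   # call straight through py4j
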
pyrasterframes/python/pyrasterframes/rasterfunctions.py

Lines changed: 25 additions & 15 deletions
@@ -8,14 +8,14 @@
 from __future__ import absolute_import
 from pyspark.sql.types import *
 from pyspark.sql.column import Column, _to_java_column
-from .context import _checked_context
+from .context import RFContext


 THIS_MODULE = 'pyrasterframes'


 def _context_call(name, *args):
-    f = getattr(_checked_context(), name)
+    f = RFContext.active().lookup(name)
     return f(*args)

@@ -27,8 +27,7 @@ def _celltype(cellTypeStr):
 def _create_assembleTile():
     """ Create a function mapping to the Scala implementation."""
     def _(colIndex, rowIndex, cellData, numCols, numRows, cellType):
-        ctx = _checked_context()
-        jfcn = getattr(ctx, 'assembleTile')
+        jfcn = RFContext.active().lookup('assembleTile')
         return Column(jfcn(_to_java_column(colIndex), _to_java_column(rowIndex), _to_java_column(cellData), numCols, numRows, _celltype(cellType)))
     _.__name__ = 'assembleTile'
     _.__doc__ = "Create a Tile from a column of cell data with location indices"
@@ -39,7 +38,7 @@ def _(colIndex, rowIndex, cellData, numCols, numRows, cellType):
 def _create_arrayToTile():
     """ Create a function mapping to the Scala implementation."""
     def _(arrayCol, numCols, numRows):
-        jfcn = getattr(_checked_context(), 'arrayToTile')
+        jfcn = RFContext.active().lookup('arrayToTile')
         return Column(jfcn(_to_java_column(arrayCol), numCols, numRows))
     _.__name__ = 'arrayToTile'
     _.__doc__ = "Convert array in `arrayCol` into a Tile of dimensions `numCols` and `numRows'"
@@ -50,7 +49,7 @@ def _(arrayCol, numCols, numRows):
 def _create_convertCellType():
     """ Create a function mapping to the Scala implementation."""
     def _(tileCol, cellType):
-        jfcn = getattr(_checked_context(), 'convertCellType')
+        jfcn = RFContext.active().lookup('convertCellType')
         return Column(jfcn(_to_java_column(tileCol), _celltype(cellType)))
     _.__name__ = 'convertCellType'
     _.__doc__ = "Convert the numeric type of the Tiles in `tileCol`"
@@ -61,7 +60,7 @@ def _(tileCol, cellType):
 def _create_makeConstantTile():
     """ Create a function mapping to the Scala implementation."""
     def _(value, cols, rows, cellType):
-        jfcn = getattr(_checked_context(), 'makeConstantTile')
+        jfcn = RFContext.active().lookup('makeConstantTile')
         return Column(jfcn(value, cols, rows, cellType))
     _.__name__ = 'makeConstantTile'
     _.__doc__ = "Constructor for constant tile column"
@@ -72,7 +71,7 @@ def _(value, cols, rows, cellType):
 def _create_tileZeros():
     """ Create a function mapping to the Scala implementation."""
     def _(cols, rows, cellType = 'float64'):
-        jfcn = getattr(_checked_context(), 'tileZeros')
+        jfcn = RFContext.active().lookup('tileZeros')
         return Column(jfcn(cols, rows, cellType))
     _.__name__ = 'tileZeros'
     _.__doc__ = "Create column of constant tiles of zero"
@@ -83,7 +82,7 @@ def _(cols, rows, cellType = 'float64'):
 def _create_tileOnes():
     """ Create a function mapping to the Scala implementation."""
     def _(cols, rows, cellType = 'float64'):
-        jfcn = getattr(_checked_context(), 'tileOnes')
+        jfcn = RFContext.active().lookup('tileOnes')
         return Column(jfcn(cols, rows, cellType))
     _.__name__ = 'tileOnes'
     _.__doc__ = "Create column of constant tiles of one"
@@ -94,7 +93,7 @@ def _(cols, rows, cellType = 'float64'):
 def _create_rasterize():
     """ Create a function mapping to the Scala rasterize function. """
     def _(geometryCol, boundsCol, valueCol, numCols, numRows):
-        jfcn = getattr(_checked_context(), 'rasterize')
+        jfcn = RFContext.active().lookup('rasterize')
         return Column(jfcn(_to_java_column(geometryCol), _to_java_column(boundsCol), _to_java_column(valueCol), numCols, numRows))
     _.__name__ = 'rasterize'
     _.__doc__ = 'Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value.'
@@ -104,7 +103,7 @@ def _(geometryCol, boundsCol, valueCol, numCols, numRows):
 def _create_reproject_geometry():
     """ Create a function mapping to the Scala reprojectGeometry function. """
     def _(geometryCol, srcCRSName, dstCRSName):
-        jfcn = getattr(_checked_context(), 'reprojectGeometry')
+        jfcn = RFContext.active().lookup('reprojectGeometry')
         return Column(jfcn(_to_java_column(geometryCol), srcCRSName, dstCRSName))
     _.__name__ = 'reprojectGeometry'
     _.__doc__ = """Reproject a column of geometry given the CRS names of the source and destination.
@@ -114,6 +113,17 @@ def _(geometryCol, srcCRSName, dstCRSName):
     _.__module__ = THIS_MODULE
     return _

+def _create_explode_tiles():
+    """ Create a function mapping to Scala explodeTiles function """
+    def _(*args):
+        jfcn = RFContext.active().lookup('explodeTiles')
+        jcols = [_to_java_column(arg) for arg in args]
+        return Column(jfcn(RFContext.active().list_to_seq(jcols)))
+    _.__name__ = 'explodeTiles'
+    _.__doc__ = 'Create a row for each cell in Tile.'
+    _.__module__ = THIS_MODULE
+    return _
+
 _rf_unique_functions = {
     'assembleTile': _create_assembleTile(),
     'arrayToTile': _create_arrayToTile(),
@@ -123,7 +133,8 @@ def _(geometryCol, srcCRSName, dstCRSName):
     'tileOnes': _create_tileOnes(),
     'cellTypes': lambda: _context_call('cellTypes'),
     'rasterize': _create_rasterize(),
-    'reprojectGeometry': _create_reproject_geometry()
+    'reprojectGeometry': _create_reproject_geometry(),
+    'explodeTiles': _create_explode_tiles()
 }

@@ -154,7 +165,6 @@ def _(geometryCol, srcCRSName, dstCRSName):

 _rf_column_functions = {
     # ------- RasterFrames functions -------
-    'explodeTiles': 'Create a row for each cell in Tile.',
     'tileDimensions': 'Query the number of (cols, rows) in a Tile.',
     'envelope': 'Extracts the bounding box (envelope) of the geometry.',
     'tileToIntArray': 'Flattens Tile into an array of integers.',
@@ -280,7 +290,7 @@ def _(geometryCol, srcCRSName, dstCRSName):
 def _create_column_function(name, doc=""):
     """ Create a mapping to Scala UDF for a column function by name"""
     def _(*args):
-        jfcn = getattr(_checked_context(), name)
+        jfcn = RFContext.active().lookup(name)
         jcols = [_to_java_column(arg) for arg in args]
         return Column(jfcn(*jcols))
     _.__name__ = name
@@ -292,7 +302,7 @@ def _(*args):
 def _create_columnScalarFunction(name, doc=""):
     """ Create a mapping to Scala UDF for a (column, scalar) -> column function by name"""
     def _(col, scalar):
-        jfcn = getattr(_checked_context(), name)
+        jfcn = RFContext.active().lookup(name)
         return Column(jfcn(_to_java_column(col), scalar))
     _.__name__ = name
     _.__doc__ = doc

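The new `explodeTiles` wrapper is exercised by the test added in PyRasterFramesTests.py below; as a usage sketch (the DataFrame `rf` and its `tile` column are illustrative, mirroring that test):

from pyrasterframes.rasterfunctions import explodeTiles

# One output row per tile cell; the varargs columns are shipped to Scala as a Seq
# via RFContext.list_to_seq / PyRFContext.listToSeq.
cells = rf.select('spatial_key', explodeTiles('tile'))
cells.show(5)
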
pyrasterframes/python/pyrasterframes/types.py

Lines changed: 3 additions & 14 deletions
@@ -11,20 +11,9 @@ class here provides the PyRasterFrames entry point.
 from pyspark.sql.types import *
 from pyspark.ml.wrapper import JavaTransformer
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
-from .context import _checked_context
+from .context import RFContext

-__all__ = ['RFContext', 'RasterFrame', 'TileUDT', 'TileExploder', 'NoDataFilter']
-
-class RFContext(object):
-    """
-    Entrypoint to RasterFrames services
-    """
-    def __init__(self, spark_session):
-        self._spark_session = spark_session
-        self._gateway = spark_session.sparkContext._gateway
-        self._jvm = self._gateway.jvm
-        jsess = self._spark_session._jsparkSession
-        self._jrfctx = self._jvm.astraea.spark.rasterframes.py.PyRFContext(jsess)
+__all__ = ['RasterFrame', 'TileUDT', 'TileExploder', 'NoDataFilter']


 class RasterFrame(DataFrame):
@@ -156,7 +145,7 @@ def serialize(self, obj):
             obj.toBytes)

     def deserialize(self, datum):
-        return _checked_context().generateTile(datum[0], datum[1], datum[2], datum[3])
+        return RFContext._jvm_mirror().generateTile(datum[0], datum[1], datum[2], datum[3])

pyrasterframes/python/tests/PyRasterFramesTests.py

Lines changed: 7 additions & 0 deletions
@@ -158,12 +158,19 @@ def test_sql(self):
         self.assertTrue(_rounded_compare(statsRow.base, statsRow.double / 2))
         self.assertTrue(_rounded_compare(statsRow.base, statsRow.half * 2))

+    def test_explode(self):
+        self.rf.select('spatial_key', explodeTiles(self.tileCol)).show()
+

 def suite():
     functionTests = unittest.TestSuite()
     functionTests.addTest(RasterFunctionsTest('test_identify_columns'))
+    functionTests.addTest(RasterFunctionsTest('test_tile_operations'))
     functionTests.addTest(RasterFunctionsTest('test_general'))
+    functionTests.addTest(RasterFunctionsTest('test_rasterize'))
+    functionTests.addTest(RasterFunctionsTest('test_reproject'))
     functionTests.addTest(RasterFunctionsTest('test_aggregations'))
+    functionTests.addTest(RasterFunctionsTest('test_explode'))
     functionTests.addTest(RasterFunctionsTest('test_sql'))
     return functionTests

pyrasterframes/src/main/scala/astraea/spark/rasterframes/py/PyRFContext.scala

Lines changed: 6 additions & 10 deletions
@@ -20,21 +20,15 @@ package astraea.spark.rasterframes.py
 
 import astraea.spark.rasterframes._
 import astraea.spark.rasterframes.util.CRSParser
-
 import com.vividsolutions.jts.geom.Geometry
-
-import geotrellis.raster.{ArrayTile, CellType, Tile, MultibandTile}
-import geotrellis.spark.{SpatialKey, SpaceTimeKey, TileLayerMetadata, MultibandTileLayerRDD, ContextRDD}
+import geotrellis.raster.{ArrayTile, CellType, MultibandTile}
 import geotrellis.spark.io._
-
-import org.locationtech.geomesa.spark.jts.util.WKBUtils
-
-import org.apache.spark.rdd.RDD
+import geotrellis.spark.{ContextRDD, MultibandTileLayerRDD, SpaceTimeKey, SpatialKey, TileLayerMetadata}
 import org.apache.spark.sql._
-
+import org.locationtech.geomesa.spark.jts.util.WKBUtils
 import spray.json._
-import astraea.spark.rasterframes.ml.NoDataFilter

+import scala.collection.JavaConverters._

 /**
   * py4j access wrapper to RasterFrame entry points.
@@ -193,4 +187,6 @@ class PyRFContext(implicit sparkSession: SparkSession) extends RasterFunctions
     val dst = CRSParser(dstName)
     reprojectGeometry(geometryCol, src, dst)
   }
+
+  def listToSeq(cols: java.util.ArrayList[AnyRef]): Seq[AnyRef] = cols.asScala
 }

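The `listToSeq` helper exists because py4j delivers a Python list as a `java.util.ArrayList`, while the Scala `explodeTiles` entry point expects a Scala `Seq`. A rough restatement of the round trip wired up in this commit (the function name `explode_tiles_sketch` is illustrative; the calls come from the diffs above):

from pyspark.sql.column import Column, _to_java_column
from pyrasterframes.context import RFContext

def explode_tiles_sketch(*cols):
    """Illustrative restatement of the explodeTiles wiring from this commit."""
    ctx = RFContext.active()
    jcols = [_to_java_column(c) for c in cols]  # Python list of Java Column handles
    jseq = ctx.list_to_seq(jcols)               # py4j ships the list as java.util.ArrayList;
                                                # PyRFContext.listToSeq converts it to a Scala Seq
    return Column(ctx.lookup('explodeTiles')(jseq))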