@@ -26,6 +26,8 @@ class here provides the PyRasterFrames entry point.
2626"""
2727from itertools import product
2828import functools , math
29+
30+ import pyproj
2931from pyspark import SparkContext
3032from pyspark .sql import DataFrame , Column
3133from pyspark .sql .types import (UserDefinedType , StructType , StructField , BinaryType , DoubleType , ShortType , IntegerType , StringType )
@@ -42,7 +44,8 @@ class here provides the PyRasterFrames entry point.
4244
4345from typing import List , Tuple
4446
45- __all__ = ['RasterFrameLayer' , 'Tile' , 'TileUDT' , 'CellType' , 'Extent' , 'CRS' , 'RasterSourceUDT' , 'TileExploder' , 'NoDataFilter' ]
47+ __all__ = ['RasterFrameLayer' , 'Tile' , 'TileUDT' , 'CellType' , 'Extent' ,
48+ 'CRS' , 'CrsUDT' , 'RasterSourceUDT' , 'TileExploder' , 'NoDataFilter' ]
4649
4750
4851class cached_property (object ):
@@ -227,7 +230,12 @@ def __str__(self):
227230class CRS (object ):
228231 # NB: The name `crsProj4` has to match what's used in StandardSerializers.crsSerializers
def __init__(self, crsProj4):
    """
    Build a CRS from either a `pyproj.CRS` or a proj4 definition string.

    :param crsProj4: A `pyproj.CRS` instance (converted with `to_proj4()`),
        or a `str` that is stored unchanged.
    :raises ValueError: If `crsProj4` is neither a `pyproj.CRS` nor a `str`.
    """
    # pyproj objects are reduced to their proj4 text so that `crsProj4`
    # always holds a plain string.
    if isinstance(crsProj4, pyproj.CRS):
        self.crsProj4 = crsProj4.to_proj4()
        return

    if not isinstance(crsProj4, str):
        raise ValueError('Unexpected CRS definition type: {}'.format(type(crsProj4)))

    self.crsProj4 = crsProj4
231239
232240 @cached_property
233241 def __jvm__ (self ):
@@ -242,9 +250,13 @@ def proj4_str(self):
242250 """Alias for `crsProj4`"""
243251 return self .crsProj4
244252
def __eq__(self, other):
    """
    Compare two CRS objects by their proj4 string.

    :param other: Object to compare against.
    :return: True only when `other` is a `CRS` whose `crsProj4` equals this one's.
    """
    return isinstance(other, CRS) and self.crsProj4 == other.crsProj4

def __hash__(self):
    """
    Hash on the proj4 string, consistent with `__eq__`.

    A class that defines `__eq__` without `__hash__` has its `__hash__`
    implicitly set to None, which would make CRS instances unusable as
    dict keys or set members.

    :return: Hash of `crsProj4`.
    """
    return hash(self.crsProj4)
255+
245256
246257class CellType (object ):
def __init__(self, cell_type_name):
    """
    Wrap a cell type name.

    :param cell_type_name: Name of the cell type; must be a `str`.
    :raises TypeError: If `cell_type_name` is not a `str`.
    """
    # Raise explicitly instead of using `assert`: assertions are removed
    # when Python runs with -O, which would silently skip the validation.
    if not isinstance(cell_type_name, str):
        raise TypeError(
            'Expected cell type name to be a str, got: {}'.format(type(cell_type_name)))
    self.cell_type_name = cell_type_name
249261
250262 @classmethod
@@ -443,29 +455,34 @@ def sqlType(cls):
443455 """
444456 Mirrors `schema` in scala companion object org.apache.spark.sql.rf.TileUDT
445457 """
458+ extent = StructType ([
459+ StructField ("xmin" ,DoubleType (), True ),
460+ StructField ("ymin" ,DoubleType (), True ),
461+ StructField ("xmax" ,DoubleType (), True ),
462+ StructField ("ymax" ,DoubleType (), True )
463+ ])
464+ subgrid = StructType ([
465+ StructField ("colMin" , IntegerType (), True ),
466+ StructField ("rowMin" , IntegerType (), True ),
467+ StructField ("colMax" , IntegerType (), True ),
468+ StructField ("rowMax" , IntegerType () ,True )
469+ ])
470+
471+ ref = StructType ([
472+ StructField ("source" , StructType ([
473+ StructField ("raster_source_kryo" , BinaryType (), False )
474+ ]),True ),
475+ StructField ("bandIndex" , IntegerType (), True ),
476+ StructField ("subextent" , extent ,True ),
477+ StructField ("subgrid" , subgrid , True ),
478+ ])
479+
446480 return StructType ([
447- StructField ("cell_context" , StructType ([
448- StructField ("cellType" , StructType ([
449- StructField ("cellTypeName" , StringType (), False )
450- ]), False ),
451- StructField ("dimensions" , StructType ([
452- StructField ("cols" , ShortType (), False ),
453- StructField ("rows" , ShortType (), False )
454- ]), False ),
455- ]), False ),
456- StructField ("cell_data" , StructType ([
457- StructField ("cells" , BinaryType (), True ),
458- StructField ("ref" , StructType ([
459- StructField ("source" , RasterSourceUDT (), False ),
460- StructField ("bandIndex" , IntegerType (), False ),
461- StructField ("subextent" , StructType ([
462- StructField ("xmin" , DoubleType (), False ),
463- StructField ("ymin" , DoubleType (), False ),
464- StructField ("xmax" , DoubleType (), False ),
465- StructField ("ymax" , DoubleType (), False )
466- ]), True )
467- ]), True )
468- ]), False )
481+ StructField ("cellType" , StringType (), False ),
482+ StructField ("cols" , IntegerType (), False ),
483+ StructField ("rows" , IntegerType (), False ),
484+ StructField ("cells" , BinaryType (), True ),
485+ StructField ("ref" , ref , True )
469486 ])
470487
471488 @classmethod
@@ -478,20 +495,14 @@ def scalaUDT(cls):
478495
def serialize(self, tile):
    """
    Flatten a tile into the row layout declared by `sqlType`.

    :param tile: Object exposing `cells` (with `flatten().tobytes()`),
        `cell_type.cell_type_name` and `dimensions()`.
    :return: A list of `[cellType, cols, rows, cells, ref]`, where `cells`
        is a `bytearray` of the flattened cell data and `ref` is always None.
    """
    raw_cells = tile.cells.flatten().tobytes()
    payload = bytearray(raw_cells)
    shape = tile.dimensions()
    type_name = tile.cell_type.cell_type_name
    # The first two entries of `dimensions()` fill the `cols` and `rows` slots.
    n_cols, n_rows = shape[0], shape[1]
    # The trailing None fills the `ref` slot: no raster reference is written.
    return [type_name, n_cols, n_rows, payload, None]
495506
496507 def deserialize (self , datum ):
497508 """
@@ -500,21 +511,21 @@ def deserialize(self, datum):
500511 :return: A Tile object from row data.
501512 """
502513
503- cell_data_bytes = datum .cell_data . cells
514+ cell_data_bytes = datum .cells
504515 if cell_data_bytes is None :
505- if datum .cell_data . ref is None :
516+ if datum .ref is None :
506517 raise Exception ("Invalid Tile structure. Missing cells and reference" )
507518 else :
508- payload = datum .cell_data . ref
519+ payload = datum .ref
509520 ref = RFContext .active ()._resolve_raster_ref (payload )
510521 cell_type = CellType (ref .cellType ().name ())
511522 cols = ref .cols ()
512523 rows = ref .rows ()
513524 cell_data_bytes = ref .tile ().toBytes ()
514525 else :
515- cell_type = CellType (datum .cell_context . cellType . cellTypeName )
516- cols = datum .cell_context . dimensions . cols
517- rows = datum .cell_context . dimensions . rows
526+ cell_type = CellType (datum .cellType )
527+ cols = datum .cols
528+ rows = datum .rows
518529
519530 if cell_data_bytes is None :
520531 raise Exception ("Unable to fetch cell data from: " + repr (datum ))
@@ -540,6 +551,34 @@ def deserialize(self, datum):
540551Tile .__UDT__ = TileUDT ()
541552
542553
class CrsUDT(UserDefinedType):
    """
    User-defined type that stores `CRS` values as their proj4 string.
    """

    @classmethod
    def module(cls):
        """
        :return: Name of the Python module that hosts this UDT.
        """
        return 'pyrasterframes.rf_types'

    @classmethod
    def scalaUDT(cls):
        """
        :return: Fully qualified name of the Scala class paired with this UDT.
        """
        return 'org.apache.spark.sql.rf.CrsUDT'

    @classmethod
    def sqlType(cls):
        """
        Storage type of the column: a plain string, mirroring `schema` in the
        Scala companion object org.apache.spark.sql.rf.CrsUDT.
        """
        return StringType()

    def deserialize(self, datum):
        """
        :param datum: Stored value, passed straight to the `CRS` constructor.
        :return: A `CRS` built from `datum`.
        """
        return CRS(datum)

    # NOTE(review): the consumer of this flag is not visible here; presumably
    # the unpickling machinery checks for it. Confirm before removing.
    deserialize.__safe_for_unpickling__ = True

    def serialize(self, crs):
        """
        :param crs: Object exposing a `proj4_str` attribute.
        :return: The value of `crs.proj4_str`.
        """
        return crs.proj4_str
577+
578+
579+ CRS .__UDT__ = CrsUDT ()
580+
581+
543582class TileExploder (JavaTransformer , DefaultParamsReadable , DefaultParamsWritable ):
544583 """
545584 Python wrapper for TileExploder.scala
0 commit comments