Skip to content

Commit 1cd8df4

Browse files
committed
Implemented CrsUDT in Python.
1 parent 0af5101 commit 1cd8df4

File tree

7 files changed

+292
-209
lines changed

7 files changed

+292
-209
lines changed

core/src/main/scala/org/apache/spark/sql/rf/TileUDT.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class TileUDT extends UserDefinedType[Tile] {
4545
def userClass: Class[Tile] = classOf[Tile]
4646

4747
def sqlType: StructType = StructType(Seq(
48-
StructField("cell_type", StringType, false),
48+
StructField("cellType", StringType, false),
4949
StructField("cols", IntegerType, false),
5050
StructField("rows", IntegerType, false),
5151
StructField("cells", BinaryType, true),

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ package object expressions {
8282
registry.registerExpression[GetCRS]("rf_crs")
8383
registry.registerExpression[RealizeTile]("rf_tile")
8484
registry.registerExpression[CreateProjectedRaster]("rf_proj_raster")
85-
registry.registerExpression[Subtract]("rf_local_subtract")
8685
registry.registerExpression[Multiply]("rf_local_multiply")
8786
registry.registerExpression[Divide]("rf_local_divide")
8887
registry.registerExpression[NormalizedDifference]("rf_normalized_difference")

core/src/main/scala/org/locationtech/rasterframes/ref/Subgrid.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2021 Azavea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
121
package org.locationtech.rasterframes.ref
222

323
import geotrellis.raster.GridBounds

pyrasterframes/src/main/python/pyrasterframes/rf_types.py

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class here provides the PyRasterFrames entry point.
2626
"""
2727
from itertools import product
2828
import functools, math
29+
30+
import pyproj
2931
from pyspark import SparkContext
3032
from pyspark.sql import DataFrame, Column
3133
from pyspark.sql.types import (UserDefinedType, StructType, StructField, BinaryType, DoubleType, ShortType, IntegerType, StringType)
@@ -42,7 +44,8 @@ class here provides the PyRasterFrames entry point.
4244

4345
from typing import List, Tuple
4446

45-
__all__ = ['RasterFrameLayer', 'Tile', 'TileUDT', 'CellType', 'Extent', 'CRS', 'RasterSourceUDT', 'TileExploder', 'NoDataFilter']
47+
__all__ = ['RasterFrameLayer', 'Tile', 'TileUDT', 'CellType', 'Extent',
48+
'CRS', 'CrsUDT', 'RasterSourceUDT', 'TileExploder', 'NoDataFilter']
4649

4750

4851
class cached_property(object):
@@ -227,7 +230,12 @@ def __str__(self):
227230
class CRS(object):
228231
# NB: The name `crsProj4` has to match what's used in StandardSerializers.crsSerializers
229232
def __init__(self, crsProj4):
230-
self.crsProj4 = crsProj4
233+
if isinstance(crsProj4, pyproj.CRS):
234+
self.crsProj4 = crsProj4.to_proj4()
235+
elif isinstance(crsProj4, str):
236+
self.crsProj4 = crsProj4
237+
else:
238+
raise ValueError('Unexpected CRS definition type: {}'.format(type(crsProj4)))
231239

232240
@cached_property
233241
def __jvm__(self):
@@ -242,9 +250,13 @@ def proj4_str(self):
242250
"""Alias for `crsProj4`"""
243251
return self.crsProj4
244252

253+
def __eq__(self, other):
254+
return isinstance(other, CRS) and self.crsProj4 == other.crsProj4
255+
245256

246257
class CellType(object):
247258
def __init__(self, cell_type_name):
259+
assert(isinstance(cell_type_name, str))
248260
self.cell_type_name = cell_type_name
249261

250262
@classmethod
@@ -443,29 +455,34 @@ def sqlType(cls):
443455
"""
444456
Mirrors `schema` in scala companion object org.apache.spark.sql.rf.TileUDT
445457
"""
458+
extent = StructType([
459+
StructField("xmin",DoubleType(), True),
460+
StructField("ymin",DoubleType(), True),
461+
StructField("xmax",DoubleType(), True),
462+
StructField("ymax",DoubleType(), True)
463+
])
464+
subgrid = StructType([
465+
StructField("colMin", IntegerType(), True),
466+
StructField("rowMin", IntegerType(), True),
467+
StructField("colMax", IntegerType(), True),
468+
StructField("rowMax", IntegerType() ,True)
469+
])
470+
471+
ref = StructType([
472+
StructField("source", StructType([
473+
StructField("raster_source_kryo", BinaryType(), False)
474+
]),True),
475+
StructField("bandIndex", IntegerType(), True),
476+
StructField("subextent", extent ,True),
477+
StructField("subgrid", subgrid, True),
478+
])
479+
446480
return StructType([
447-
StructField("cell_context", StructType([
448-
StructField("cellType", StructType([
449-
StructField("cellTypeName", StringType(), False)
450-
]), False),
451-
StructField("dimensions", StructType([
452-
StructField("cols", ShortType(), False),
453-
StructField("rows", ShortType(), False)
454-
]), False),
455-
]), False),
456-
StructField("cell_data", StructType([
457-
StructField("cells", BinaryType(), True),
458-
StructField("ref", StructType([
459-
StructField("source", RasterSourceUDT(), False),
460-
StructField("bandIndex", IntegerType(), False),
461-
StructField("subextent", StructType([
462-
StructField("xmin", DoubleType(), False),
463-
StructField("ymin", DoubleType(), False),
464-
StructField("xmax", DoubleType(), False),
465-
StructField("ymax", DoubleType(), False)
466-
]), True)
467-
]), True)
468-
]), False)
481+
StructField("cellType", StringType(), False),
482+
StructField("cols", IntegerType(), False),
483+
StructField("rows", IntegerType(), False),
484+
StructField("cells", BinaryType(), True),
485+
StructField("ref", ref, True)
469486
])
470487

471488
@classmethod
@@ -478,20 +495,14 @@ def scalaUDT(cls):
478495

479496
def serialize(self, tile):
480497
cells = bytearray(tile.cells.flatten().tobytes())
481-
row = [
482-
# cell_context
483-
[
484-
[tile.cell_type.cell_type_name],
485-
tile.dimensions()
486-
],
487-
# cell_data
488-
[
489-
# cells
490-
cells,
491-
None
492-
]
498+
dims = tile.dimensions()
499+
return [
500+
tile.cell_type.cell_type_name,
501+
dims[0],
502+
dims[1],
503+
cells,
504+
None
493505
]
494-
return row
495506

496507
def deserialize(self, datum):
497508
"""
@@ -500,21 +511,21 @@ def deserialize(self, datum):
500511
:return: A Tile object from row data.
501512
"""
502513

503-
cell_data_bytes = datum.cell_data.cells
514+
cell_data_bytes = datum.cells
504515
if cell_data_bytes is None:
505-
if datum.cell_data.ref is None:
516+
if datum.ref is None:
506517
raise Exception("Invalid Tile structure. Missing cells and reference")
507518
else:
508-
payload = datum.cell_data.ref
519+
payload = datum.ref
509520
ref = RFContext.active()._resolve_raster_ref(payload)
510521
cell_type = CellType(ref.cellType().name())
511522
cols = ref.cols()
512523
rows = ref.rows()
513524
cell_data_bytes = ref.tile().toBytes()
514525
else:
515-
cell_type = CellType(datum.cell_context.cellType.cellTypeName)
516-
cols = datum.cell_context.dimensions.cols
517-
rows = datum.cell_context.dimensions.rows
526+
cell_type = CellType(datum.cellType)
527+
cols = datum.cols
528+
rows = datum.rows
518529

519530
if cell_data_bytes is None:
520531
raise Exception("Unable to fetch cell data from: " + repr(datum))
@@ -540,6 +551,34 @@ def deserialize(self, datum):
540551
Tile.__UDT__ = TileUDT()
541552

542553

554+
class CrsUDT(UserDefinedType):
555+
@classmethod
556+
def sqlType(cls):
557+
"""
558+
Mirrors `schema` in scala companion object org.apache.spark.sql.rf.CrsUDT
559+
"""
560+
return StringType()
561+
562+
@classmethod
563+
def module(cls):
564+
return 'pyrasterframes.rf_types'
565+
566+
@classmethod
567+
def scalaUDT(cls):
568+
return 'org.apache.spark.sql.rf.CrsUDT'
569+
570+
def serialize(self, crs):
571+
return crs.proj4_str
572+
573+
def deserialize(self, datum):
574+
return CRS(datum)
575+
576+
deserialize.__safe_for_unpickling__ = True
577+
578+
579+
CRS.__UDT__ = CrsUDT()
580+
581+
543582
class TileExploder(JavaTransformer, DefaultParamsReadable, DefaultParamsWritable):
544583
"""
545584
Python wrapper for TileExploder.scala

0 commit comments

Comments
 (0)