|
| 1 | +/* |
| 2 | + * This software is licensed under the Apache 2 license, quoted below. |
| 3 | + * |
| 4 | + * Copyright 2019 Astraea, Inc. |
| 5 | + * |
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| 7 | + * use this file except in compliance with the License. You may obtain a copy of |
| 8 | + * the License at |
| 9 | + * |
| 10 | + * [http://www.apache.org/licenses/LICENSE-2.0] |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, software |
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| 14 | + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| 15 | + * License for the specific language governing permissions and limitations under |
| 16 | + * the License. |
| 17 | + * |
| 18 | + * SPDX-License-Identifier: Apache-2.0 |
| 19 | + * |
| 20 | + */ |
| 21 | + |
| 22 | +package org.locationtech.rasterframes.expressions.localops |
| 23 | + |
| 24 | +import geotrellis.raster.Tile |
| 25 | +import geotrellis.raster.mapalgebra.local.IfCell |
| 26 | +import org.apache.spark.sql.Column |
| 27 | +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult |
| 28 | +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} |
| 29 | +import org.apache.spark.sql.types.{ArrayType, DataType} |
| 30 | +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback |
| 31 | +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription} |
| 32 | +import org.apache.spark.sql.catalyst.util.ArrayData |
| 33 | +import org.apache.spark.sql.rf.TileUDT |
| 34 | +import org.locationtech.rasterframes.encoders.CatalystSerializer._ |
| 35 | +import org.locationtech.rasterframes.expressions.DynamicExtractors._ |
| 36 | +import org.locationtech.rasterframes.expressions._ |
| 37 | + |
| 38 | +@ExpressionDescription( |
| 39 | + usage = "_FUNC_(tile, rhs) - In each cell of `tile`, return true if the value is in rhs.", |
| 40 | + arguments = """ |
| 41 | + Arguments: |
| 42 | + * tile - tile column to apply abs |
| 43 | + * rhs - array to test against |
| 44 | + """, |
| 45 | + examples = """ |
| 46 | + Examples: |
| 47 | + > SELECT _FUNC_(tile, array); |
| 48 | + ...""" |
| 49 | +) |
| 50 | +case class IsIn(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback { |
| 51 | + override val nodeName: String = "rf_local_is_in" |
| 52 | + |
| 53 | + override def dataType: DataType = left.dataType |
| 54 | + |
| 55 | + @transient private lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType |
| 56 | + |
| 57 | + override def checkInputDataTypes(): TypeCheckResult = |
| 58 | + if(!tileExtractor.isDefinedAt(left.dataType)) { |
| 59 | + TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.") |
| 60 | + } else right.dataType match { |
| 61 | + case _: ArrayType ⇒ TypeCheckSuccess |
| 62 | + case _ ⇒ TypeCheckFailure(s"Input type '${right.dataType}' does not conform to ArrayType.") |
| 63 | + } |
| 64 | + |
| 65 | + override protected def nullSafeEval(input1: Any, input2: Any): Any = { |
| 66 | + implicit val tileSer = TileUDT.tileSerializer |
| 67 | + val (childTile, childCtx) = tileExtractor(left.dataType)(row(input1)) |
| 68 | + |
| 69 | + val arr = input2.asInstanceOf[ArrayData].toArray[AnyRef](elementType) |
| 70 | + |
| 71 | + childCtx match { |
| 72 | + case Some(ctx) => ctx.toProjectRasterTile(op(childTile, arr)).toInternalRow |
| 73 | + case None => op(childTile, arr).toInternalRow |
| 74 | + } |
| 75 | + |
| 76 | + } |
| 77 | + |
| 78 | + protected def op(left: Tile, right: IndexedSeq[AnyRef]): Tile = { |
| 79 | + def fn(i: Int): Boolean = right.contains(i) |
| 80 | + IfCell(left, fn(_), 1, 0) |
| 81 | + } |
| 82 | + |
| 83 | +} |
| 84 | + |
| 85 | +object IsIn { |
| 86 | + def apply(left: Column, right: Column): Column = |
| 87 | + new Column(IsIn(left.expr, right.expr)) |
| 88 | +} |
0 commit comments