Commit 1a74338

Remove STACDataFrame searchLimit parameter

1 parent fdd46cb · commit 1a74338

8 files changed: +63 −51 lines changed

datasource/src/main/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiDataSource.scala

Lines changed: 1 addition & 2 deletions

@@ -16,12 +16,11 @@ class StacApiDataSource extends TableProvider with DataSourceRegister {
   def getTable(structType: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table =
     new StacApiTable()
 
-  override def shortName(): String = "stac-api"
+  def shortName(): String = StacApiDataSource.SHORT_NAME
 }
 
 object StacApiDataSource {
   final val SHORT_NAME = "stac-api"
   final val URI_PARAM = "uri"
   final val SEARCH_FILTERS_PARAM = "search-filters"
-  final val SEARCH_LIMIT_PARAM = "search-limit"
 }
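Note: shortName now delegates to the companion constant, so callers can reference StacApiDataSource.SHORT_NAME instead of repeating the "stac-api" literal. A minimal sketch of the equivalence (an existing SparkSession named spark is assumed):

    import org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource

    // both readers resolve to the same data source, registered via DataSourceRegister
    val viaConstant = spark.read.format(StacApiDataSource.SHORT_NAME)
    val viaLiteral  = spark.read.format("stac-api")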

datasource/src/main/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiPartition.scala

Lines changed: 9 additions & 17 deletions

@@ -7,13 +7,12 @@ import com.azavea.stac4s.StacItem
 import geotrellis.store.util.BlockingThreadPool
 import sttp.client3.asynchttpclient.cats.AsyncHttpClientCatsBackend
 import com.azavea.stac4s.api.client._
-import eu.timepit.refined.types.numeric.NonNegInt
 import cats.effect.IO
 import sttp.model.Uri
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 
-case class StacApiPartition(uri: Uri, searchFilters: SearchFilters, searchLimit: Option[NonNegInt]) extends InputPartition
+case class StacApiPartition(uri: Uri, searchFilters: SearchFilters) extends InputPartition
 
 class StacApiPartitionReaderFactory extends PartitionReaderFactory {
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -25,24 +24,17 @@ class StacApiPartitionReaderFactory extends PartitionReaderFactory {
 }
 
 class StacApiPartitionReader(partition: StacApiPartition) extends PartitionReader[InternalRow] {
-  lazy val partitionValues: Iterator[StacItem] = {
-    implicit val cs = IO.contextShift(BlockingThreadPool.executionContext)
-    AsyncHttpClientCatsBackend
-      .resource[IO]()
-      .use { backend =>
-        SttpStacClient(backend, partition.uri)
-          .search(partition.searchFilters)
-          .take(partition.searchLimit.map(_.value))
-          .compile
-          .toList
-      }
-      .map(_.toIterator)
-      .unsafeRunSync()
-  }
+
+  @transient private implicit lazy val cs = IO.contextShift(BlockingThreadPool.executionContext)
+  @transient private lazy val backend = AsyncHttpClientCatsBackend[IO]().unsafeRunSync()
+  @transient private lazy val partitionValues: Iterator[StacItem] =
+    SttpStacClient(backend, partition.uri)
+      .search(partition.searchFilters)
+      .toIterator(_.unsafeRunSync())
 
   def next: Boolean = partitionValues.hasNext
 
   def get: InternalRow = partitionValues.next.toInternalRow
 
-  def close(): Unit = { }
+  def close(): Unit = backend.close().unsafeRunSync()
 }
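A note on the new reader fields: Spark serializes reader state between driver and executors, and an HTTP backend is not serializable. The @transient lazy val combination skips the field during serialization and rebuilds it on first access on the executor. A hypothetical, self-contained sketch of the pattern (names here are illustrative, not from the commit):

    import java.util.concurrent.{ExecutorService, Executors}

    class ResourceHolder extends Serializable {
      // @transient: skipped by Java serialization; lazy: re-created on first
      // access after deserialization, so the non-serializable resource is
      // built where it is used instead of being shipped over the wire
      @transient private lazy val pool: ExecutorService =
        Executors.newSingleThreadExecutor() // stands in for the sttp backend
      def shutdown(): Unit = pool.shutdown() // mirrors close() releasing the backend
    }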

datasource/src/main/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiScanBuilder.scala

Lines changed: 4 additions & 4 deletions

@@ -8,12 +8,12 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead
 import org.apache.spark.sql.types.StructType
 import sttp.model.Uri
 
-class StacApiScanBuilder(uri: Uri, searchFilters: SearchFilters, searchLimit: Option[NonNegInt]) extends ScanBuilder {
-  override def build(): Scan = new StacApiBatchScan(uri, searchFilters, searchLimit)
+class StacApiScanBuilder(uri: Uri, searchFilters: SearchFilters) extends ScanBuilder {
+  def build(): Scan = new StacApiBatchScan(uri, searchFilters)
 }
 
 /** Batch Reading Support. The schema is repeated here as it can change after column pruning, etc. */
-class StacApiBatchScan(uri: Uri, searchFilters: SearchFilters, searchLimit: Option[NonNegInt]) extends Scan with Batch {
+class StacApiBatchScan(uri: Uri, searchFilters: SearchFilters) extends Scan with Batch {
   def readSchema(): StructType = stacItemEncoder.schema
 
   override def toBatch: Batch = this
@@ -23,6 +23,6 @@ class StacApiBatchScan(uri: Uri, searchFilters: SearchFilters, searchLimit: Opti
   * To perform a distributed load, we'd need to know some internals about how the next page token is computed.
   * This can be a good idea for the STAC Spec extension.
   * */
-  def planInputPartitions(): Array[InputPartition] = Array(StacApiPartition(uri, searchFilters, searchLimit))
+  def planInputPartitions(): Array[InputPartition] = Array(StacApiPartition(uri, searchFilters))
   def createReaderFactory(): PartitionReaderFactory = new StacApiPartitionReaderFactory()
 }
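Since planInputPartitions returns a one-element array, every load still lands in a single Spark partition, which the tests below assert. A hedged sketch of the observable effect (the endpoint is borrowed from the tests, and an existing SparkSession is assumed):

    val results = spark.read
      .format("stac-api")
      .option("uri", "https://eod-catalog-svc-prod.astraea.earth/")
      .load()
    assert(results.rdd.getNumPartitions == 1) // sequential paging => one partition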

datasource/src/main/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiTable.scala

Lines changed: 3 additions & 5 deletions

@@ -7,8 +7,8 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapabil
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource.{SEARCH_LIMIT_PARAM, SEARCH_FILTERS_PARAM, URI_PARAM}
-import org.locationtech.rasterframes.datasource.{intParam, jsonParam, uriParam}
+import org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource.{SEARCH_FILTERS_PARAM, URI_PARAM}
+import org.locationtech.rasterframes.datasource.{jsonParam, uriParam}
 import sttp.model.Uri
 
 import scala.collection.JavaConverters._
@@ -24,7 +24,7 @@ class StacApiTable extends Table with SupportsRead {
   def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
 
   def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StacApiScanBuilder(options.uri, options.searchFilters, options.searchLimit)
+    new StacApiScanBuilder(options.uri, options.searchFilters)
 }
 
 object StacApiTable {
@@ -35,7 +35,5 @@ object StacApiTable {
     jsonParam(SEARCH_FILTERS_PARAM, options)
      .flatMap(_.as[SearchFilters].toOption)
      .getOrElse(SearchFilters(limit = NonNegInt.from(30).toOption))
-
-    def searchLimit: Option[NonNegInt] = intParam(SEARCH_LIMIT_PARAM, options).flatMap(NonNegInt.from(_).toOption)
  }
 }
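With searchLimit gone, uri and search-filters are the only options StacApiTable consumes; a missing or undecodable search-filters value falls back to SearchFilters(limit = NonNegInt.from(30).toOption). A minimal sketch of passing the filters as raw JSON (the filter body is a hypothetical example):

    val df = spark.read
      .format("stac-api")
      .option("uri", "https://franklin.nasa-hsi.azavea.com/")
      .option("search-filters", """{"collections": ["aviris-l1-cogs"]}""")
      .load()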

datasource/src/main/scala/org/locationtech/rasterframes/datasource/stac/api/package.scala

Lines changed: 26 additions & 4 deletions

@@ -1,9 +1,11 @@
 package org.locationtech.rasterframes.datasource.stac
 
+import cats.Monad
+import cats.syntax.functor._
 import com.azavea.stac4s.api.client.SearchFilters
 import org.apache.spark.sql.{DataFrame, DataFrameReader}
 import io.circe.syntax._
-import fs2.Stream
+import fs2.{Pull, Stream}
 import shapeless.tag
 import shapeless.tag.@@
 import org.apache.spark.sql.SparkSession
@@ -17,6 +19,7 @@ package object api {
 
   implicit class StacApiDataFrameReaderOps(val reader: StacApiDataFrameReader) extends AnyVal {
     def loadStac: StacApiDataFrame = tag[StacApiDataFrameTag][DataFrame](reader.load)
+    def loadStac(limit: Int): StacApiDataFrame = tag[StacApiDataFrameTag][DataFrame](reader.load.limit(limit))
   }
 
   implicit class StacApiDataFrameOps(val df: StacApiDataFrame) extends AnyVal {
@@ -38,7 +41,27 @@ package object api {
   }
 
   implicit class Fs2StreamOps[F[_], T](val self: Stream[F, T]) {
-    def take(n: Option[Int]): Stream[F, T] = n.fold(self)(self.take(_))
+    /** Unsafe API to interop with the Spark API. */
+    def toIterator(run: F[Option[(T, fs2.Stream[F, T])]] => Option[(T, fs2.Stream[F, T])])
+                  (implicit monad: Monad[F], compiler: Stream.Compiler[F, F]): Iterator[T] = new Iterator[T] {
+      private var head = self
+      private def nextF: F[Option[(T, fs2.Stream[F, T])]] =
+        head
+          .pull.uncons1
+          .flatMap(Pull.output1)
+          .stream
+          .compile
+          .last
+          .map(_.flatten)
+
+      def hasNext(): Boolean = run(nextF).nonEmpty
+
+      def next(): T = {
+        val (item, tail) = run(nextF).get
+        this.head = tail
+        item
+      }
+    }
   }
 
   implicit class DataFrameReaderOps(val self: DataFrameReader) extends AnyVal {
@@ -48,12 +71,11 @@ package object api {
 
   implicit class DataFrameReaderStacApiOps(val reader: DataFrameReader) extends AnyVal {
     def stacApi(): StacApiDataFrameReader = tag[StacApiDataFrameTag][DataFrameReader](reader.format(StacApiDataSource.SHORT_NAME))
-    def stacApi(uri: String, filters: SearchFilters = SearchFilters(), searchLimit: Option[Int] = None): StacApiDataFrameReader =
+    def stacApi(uri: String, filters: SearchFilters = SearchFilters()): StacApiDataFrameReader =
       tag[StacApiDataFrameTag][DataFrameReader](
        stacApi()
          .option(StacApiDataSource.URI_PARAM, uri)
          .option(StacApiDataSource.SEARCH_FILTERS_PARAM, filters.asJson.noSpaces)
-         .option(StacApiDataSource.SEARCH_LIMIT_PARAM, searchLimit)
      )
  }
 }
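A hedged sketch of the new interop in isolation: toIterator turns an fs2 Stream[F, T] into a plain Scala Iterator by compiling one uncons1 pull per element, which is what lets StacApiPartitionReader stream search pages instead of materializing them with compile.toList. The sketch assumes cats-effect 2.x, matching the IO.contextShift usage elsewhere in this commit:

    import cats.effect.IO
    import fs2.Stream
    import org.locationtech.rasterframes.datasource.stac.api._ // Fs2StreamOps syntax

    val stream: Stream[IO, Int] = Stream.emits(1 to 5).covary[IO]
    val it: Iterator[Int] = stream.toIterator(_.unsafeRunSync())
    assert(it.take(2).toList == List(1, 2)) // elements are pulled on demand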

datasource/src/test/scala/org/locationtech/rasterframes/datasource/stac/api/StacApiDataSourceTest.scala

Lines changed: 18 additions & 14 deletions

@@ -25,9 +25,7 @@ import org.locationtech.rasterframes.datasource.raster._
 import org.locationtech.rasterframes.datasource.stac.api.encoders._
 import com.azavea.stac4s.StacItem
 import com.azavea.stac4s.api.client.{SearchFilters, SttpStacClient}
-import cats.syntax.option._
 import cats.effect.IO
-import eu.timepit.refined.auto._
 import geotrellis.store.util.BlockingThreadPool
 import org.apache.spark.sql.functions.explode
 import org.locationtech.rasterframes.TestEnvironment
@@ -45,9 +43,10 @@ class StacApiDataSourceTest extends TestEnvironment { self =>
       .read
       .stacApi(
         "https://franklin.nasa-hsi.azavea.com/",
-        filters = SearchFilters(items = List("aviris-l1-cogs_f130329t01p00r06_sc01")),
-        searchLimit = Some(1)
-      ).load
+        filters = SearchFilters(items = List("aviris-l1-cogs_f130329t01p00r06_sc01"))
+      )
+      .load
+      .limit(1)
 
     results.rdd.partitions.length shouldBe 1
     results.count() shouldBe 1L
@@ -78,9 +77,10 @@ class StacApiDataSourceTest extends TestEnvironment { self =>
       .read
       .stacApi(
         "https://franklin.nasa-hsi.azavea.com/",
-        filters = SearchFilters(items = List("aviris-l1-cogs_f130329t01p00r06_sc01")),
-        searchLimit = Some(1)
-      ).load
+        filters = SearchFilters(items = List("aviris-l1-cogs_f130329t01p00r06_sc01"))
+      )
+      .load
+      .limit(1)
 
     results.rdd.partitions.length shouldBe 1
 
@@ -118,10 +118,9 @@ class StacApiDataSourceTest extends TestEnvironment { self =>
       .read
       .stacApi(
         "https://franklin.nasa-hsi.azavea.com/",
-        filters = SearchFilters(items = List("aviris-l1-cogs_f130329t01p00r06_sc01")),
-        searchLimit = Some(1)
+        filters = SearchFilters(items = List("aviris-l1-cogs_f130329t01p00r06_sc01"))
       )
-      .loadStac
+      .loadStac(limit = 1) // to preserve the STAC DataFrame type
 
     val assets =
       items
@@ -149,7 +148,7 @@ class StacApiDataSourceTest extends TestEnvironment { self =>
   it("should read from Astraea Earth service") {
     import spark.implicits._
 
-    val results = spark.read.stacApi("https://eod-catalog-svc-prod.astraea.earth/", searchLimit = Some(1)).load
+    val results = spark.read.stacApi("https://eod-catalog-svc-prod.astraea.earth/").load.limit(1)
 
     // results.printSchema()
 
@@ -178,8 +177,9 @@ class StacApiDataSourceTest extends TestEnvironment { self =>
     val items =
       spark
         .read
-        .stacApi("https://eod-catalog-svc-prod.astraea.earth/", searchLimit = 1.some)
+        .stacApi("https://eod-catalog-svc-prod.astraea.earth/")
         .load
+        .limit(1)
 
     println(items.collect().toList.length)
 
@@ -199,7 +199,11 @@ class StacApiDataSourceTest extends TestEnvironment { self =>
 
   ignore("should fetch rasters from the Datacube STAC API service") {
     import spark.implicits._
-    val items = spark.read.stacApi("https://datacube.services.geo.ca/api", filters = SearchFilters(collections=List("markham")), searchLimit = Some(1)).load
+    val items = spark
+      .read
+      .stacApi("https://datacube.services.geo.ca/api", filters = SearchFilters(collections=List("markham")))
+      .load
+      .limit(1)
 
     println(items.collect().toList.length)
 

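Where the tagged DataFrame type matters, the tests use loadStac(limit = 1) rather than .load.limit(1): Dataset.limit returns a plain DataFrame, dropping the StacApiDataFrame tag, while loadStac(limit) applies the same limit and re-tags the result. A hedged sketch of the distinction:

    import org.locationtech.rasterframes.datasource.stac.api._

    // .limit erases the shapeless tag carried by StacApiDataFrame
    val plain: org.apache.spark.sql.DataFrame =
      spark.read.stacApi("https://franklin.nasa-hsi.azavea.com/").load.limit(1)

    // loadStac(limit = 1) limits and re-tags, keeping the StacApiDataFrameOps
    // syntax available on the result
    val tagged: StacApiDataFrame =
      spark.read.stacApi("https://franklin.nasa-hsi.azavea.com/").loadStac(limit = 1)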
pyrasterframes/src/main/python/pyrasterframes/__init__.py

Lines changed: 1 addition & 4 deletions

@@ -255,20 +255,17 @@ def temp_name():
 def _stac_api_reader(
         df_reader: DataFrameReader,
         uri: str,
-        filters: dict = None,
-        search_limit: Optional[int] = None) -> DataFrame:
+        filters: dict = None) -> DataFrame:
     """
     uri - STAC API uri
     filters - a STAC API Search filters dict (bbox, datetime, intersects, collections, items, limit, query, next)
-    search_limit - search results convenient limit method
     """
     import json
 
     return df_reader \
         .format("stac-api") \
         .option("uri", uri) \
         .option("search-filters", json.dumps(filters)) \
-        .option("search-limit", search_limit) \
         .load()
 
 def _geotiff_writer(

rf-notebook/src/main/notebooks/STAC API Example.ipynb

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@
 "# due to the collection size and query parameters\n",
 "# it makes sense to limit the amount of items retrieved from the STAC API\n",
 "uri = 'https://earth-search.aws.element84.com/v0'\n",
-"df = spark.read.stacapi(uri, {'collections': ['landsat-8-l1-c1']}, search_limit=100)"
+"df = spark.read.stacapi(uri, {'collections': ['landsat-8-l1-c1']}).limit(100)"
 ]
},
{
