Skip to content

Commit 82c1bb0

Browse files
authored
feat(c/sedona-libgpuspatial): Change the interface of SpatialRefiner (#717)
1 parent 4fcf9ed commit 82c1bb0

File tree

5 files changed

+51
-53
lines changed

5 files changed

+51
-53
lines changed

c/sedona-libgpuspatial/libgpuspatial/include/gpuspatial/gpuspatial_c.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,13 @@ struct SedonaSpatialRefiner {
155155
/** Clear all built geometries from the refiner */
156156
int (*clear)(struct SedonaSpatialRefiner* self);
157157

158-
int (*init_schema)(struct SedonaSpatialRefiner* self,
159-
const struct ArrowSchema* build_schema,
160-
const struct ArrowSchema* probe_schema);
158+
/** Initialize the spatial refiner with the schema of the build geometries
159+
*
160+
* @param build_schema The Arrow schema of the build geometries
161+
* @return 0 on success, non-zero on failure
162+
*/
163+
int (*init_build_schema)(struct SedonaSpatialRefiner* self,
164+
const struct ArrowSchema* build_schema);
161165

162166
/** Push geometries for building the spatial refiner
163167
*
@@ -177,6 +181,7 @@ struct SedonaSpatialRefiner {
177181
* Refine candidate pairs of geometries
178182
*
179183
* @param probe_array The Arrow array of the probe geometries
184+
* @param probe_schema The Arrow schema of the probe geometries
180185
* @param predicate The spatial relation predicate to evaluate
181186
* @param build_indices An array of build-side indices corresponding to candidate pairs.
182187
* This is a global index from 0 to N-1, where N is the total number of build geometries
@@ -188,7 +193,8 @@ struct SedonaSpatialRefiner {
188193
* @param new_indices_size Output parameter to store the number of refined pairs
189194
* @return 0 on success, non-zero on failure
190195
*/
191-
int (*refine)(struct SedonaSpatialRefiner* self, const struct ArrowArray* probe_array,
196+
int (*refine)(struct SedonaSpatialRefiner* self, const struct ArrowSchema* probe_schema,
197+
const struct ArrowArray* probe_array,
192198
enum SedonaSpatialRelationPredicate predicate, uint32_t* build_indices,
193199
uint32_t* probe_indices, uint32_t indices_size,
194200
uint32_t* new_indices_size);

c/sedona-libgpuspatial/libgpuspatial/src/gpuspatial_c.cc

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ struct GpuSpatialRefinerExporter {
271271
struct Payload {
272272
std::unique_ptr<gpuspatial::SpatialRefiner> refiner;
273273
nanoarrow::UniqueArrayView build_array_view;
274-
nanoarrow::UniqueArrayView probe_array_view;
275274
runtime_data_t* rdata;
276275
};
277276
using private_data_t = GpuSpatialWrapper<Payload>;
@@ -293,25 +292,23 @@ struct GpuSpatialRefinerExporter {
293292
auto refiner = gpuspatial::CreateRTSpatialRefiner(refiner_config);
294293

295294
out->clear = &CClear;
296-
out->init_schema = &CInitSchema;
295+
out->init_build_schema = &CInitBuildSchema;
297296
out->push_build = &CPushBuild;
298297
out->finish_building = &CFinishBuilding;
299298
out->refine = &CRefine;
300299
out->get_last_error = &CGetLastError;
301300
out->release = &CRelease;
302-
out->private_data =
303-
new private_data_t{Payload{std::move(refiner), nanoarrow::UniqueArrayView(),
304-
nanoarrow::UniqueArrayView(), rdata},
305-
""};
301+
out->private_data = new private_data_t{
302+
Payload{std::move(refiner), nanoarrow::UniqueArrayView(), rdata}, ""};
306303
}
307304

308305
static int CClear(SedonaSpatialRefiner* self) {
309306
return SafeExecute(static_cast<private_data_t*>(self->private_data),
310307
[&] { use_refiner(self).Clear(); });
311308
}
312309

313-
static int CInitSchema(SedonaSpatialRefiner* self, const ArrowSchema* build_schema,
314-
const ArrowSchema* probe_schema) {
310+
static int CInitBuildSchema(SedonaSpatialRefiner* self,
311+
const ArrowSchema* build_schema) {
315312
return SafeExecute(static_cast<private_data_t*>(self->private_data), [&] {
316313
auto* private_data = static_cast<private_data_t*>(self->private_data);
317314
ArrowError arrow_error;
@@ -320,11 +317,6 @@ struct GpuSpatialRefinerExporter {
320317
throw std::runtime_error("ArrowArrayViewInitFromSchema error " +
321318
std::string(arrow_error.message));
322319
}
323-
if (ArrowArrayViewInitFromSchema(private_data->payload.probe_array_view.get(),
324-
probe_schema, &arrow_error) != NANOARROW_OK) {
325-
throw std::runtime_error("ArrowArrayViewInitFromSchema error " +
326-
std::string(arrow_error.message));
327-
}
328320
});
329321
}
330322

@@ -348,23 +340,29 @@ struct GpuSpatialRefinerExporter {
348340
[&] { use_refiner(self).FinishBuilding(); });
349341
}
350342

351-
static int CRefine(SedonaSpatialRefiner* self, const ArrowArray* probe_array,
343+
static int CRefine(SedonaSpatialRefiner* self, const ArrowSchema* probe_schema,
344+
const ArrowArray* probe_array,
352345
SedonaSpatialRelationPredicate predicate, uint32_t* build_indices,
353346
uint32_t* probe_indices, uint32_t indices_size,
354347
uint32_t* new_indices_size) {
355348
return SafeExecute(static_cast<private_data_t*>(self->private_data), [&] {
356-
auto* private_data = static_cast<private_data_t*>(self->private_data);
357-
auto* array_view = private_data->payload.build_array_view.get();
349+
// We need to create a local ArrayView to make sure this method is thread-safe
358350
ArrowError arrow_error;
359-
360-
if (ArrowArrayViewSetArray(array_view, probe_array, &arrow_error) != NANOARROW_OK) {
351+
nanoarrow::UniqueArrayView probe_array_view;
352+
if (ArrowArrayViewInitFromSchema(probe_array_view.get(), probe_schema,
353+
&arrow_error) != NANOARROW_OK) {
354+
throw std::runtime_error("ArrowArrayViewInitFromSchema error " +
355+
std::string(arrow_error.message));
356+
}
357+
if (ArrowArrayViewSetArray(probe_array_view.get(), probe_array, &arrow_error) !=
358+
NANOARROW_OK) {
361359
throw std::runtime_error("ArrowArrayViewSetArray error " +
362360
std::string(arrow_error.message));
363361
}
364362

365363
*new_indices_size = use_refiner(self).Refine(
366-
array_view, static_cast<gpuspatial::Predicate>(predicate), build_indices,
367-
probe_indices, indices_size);
364+
probe_array_view.get(), static_cast<gpuspatial::Predicate>(predicate),
365+
build_indices, probe_indices, indices_size);
368366
});
369367
}
370368

c/sedona-libgpuspatial/libgpuspatial/test/c_wrapper_test.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class CWrapperTest : public ::testing::Test {
122122

123123
runtime_config.ptx_root = ptx_root.c_str();
124124
runtime_config.device_id = 0;
125-
runtime_config.use_cuda_memory_pool = true;
125+
runtime_config.use_cuda_memory_pool = false;
126126
runtime_config.cuda_memory_pool_init_precent = 10;
127127
ASSERT_EQ(runtime_.init(&runtime_, &runtime_config), 0);
128128

@@ -242,8 +242,6 @@ TEST_F(CWrapperTest, InitializeJoiner) {
242242
NANOARROW_OK)
243243
<< error.message;
244244

245-
refiner_.init_schema(&refiner_, build_schema.get(), probe_schema.get());
246-
247245
for (int64_t j = 0; j < probe_array->length; j++) {
248246
ArrowBufferView wkb = ArrowArrayViewGetBytesUnsafe(probe_view.get(), j);
249247
auto geom = wkb_reader.read(wkb.data.as_uint8, wkb.size_bytes);
@@ -282,13 +280,17 @@ TEST_F(CWrapperTest, InitializeJoiner) {
282280
},
283281
&intersection_ids);
284282

283+
if (i == 0) {
284+
ASSERT_EQ(refiner_.init_build_schema(&refiner_, build_schema.get()), 0);
285+
}
286+
285287
refiner_.clear(&refiner_);
286288
ASSERT_EQ(refiner_.push_build(&refiner_, build_array.get()), 0);
287289
ASSERT_EQ(refiner_.finish_building(&refiner_), 0);
288290

289291
uint32_t new_len;
290292
ASSERT_EQ(refiner_.refine(
291-
&refiner_, probe_array.get(),
293+
&refiner_, probe_schema.get(), probe_array.get(),
292294
SedonaSpatialRelationPredicate::SedonaSpatialPredicateContains,
293295
intersection_ids.build_indices_ptr, intersection_ids.probe_indices_ptr,
294296
intersection_ids.length, &new_len),

c/sedona-libgpuspatial/src/lib.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,9 @@ mod sys {
151151
Ok(Self { inner })
152152
}
153153

154-
/// Initializes the schema for the refiner based on the data types of the build and probe geometries.
155-
/// This allows the refiner to understand how to interpret the geometry data for both sets.
156-
pub fn init_schema(&mut self, build: &DataType, probe: &DataType) -> Result<()> {
157-
self.inner.init_schema(build, probe)
154+
/// Initializes the schema for the refiner based on the data types of the build geometries.
155+
pub fn init_build_schema(&mut self, build: &DataType) -> Result<()> {
156+
self.inner.init_build_schema(build)
158157
}
159158

160159
/// Clears any previously inserted data from the refiner, allowing it to be reused for building a new set of geometries.
@@ -207,7 +206,7 @@ mod sys {
207206
pub fn push_build(&mut self, _r: &[Rect<f32>]) -> Result<()> {
208207
Err(GpuSpatialError::GpuNotAvailable)
209208
}
210-
pub fn finish_building(self) -> Result<GpuSpatialIndex> {
209+
pub fn finish_building(&mut self) -> Result<GpuSpatialIndex> {
211210
Err(GpuSpatialError::GpuNotAvailable)
212211
}
213212
pub fn probe(&self, _r: &[Rect<f32>]) -> Result<(Vec<u32>, Vec<u32>)> {
@@ -219,9 +218,10 @@ mod sys {
219218
pub fn try_new(_opts: &GpuSpatialOptions) -> Result<Self> {
220219
Err(GpuSpatialError::GpuNotAvailable)
221220
}
222-
pub fn init_schema(&mut self, _b: &DataType, _p: &DataType) -> Result<()> {
221+
pub fn init_build_schema(&mut self, _b: &DataType) -> Result<()> {
223222
Err(GpuSpatialError::GpuNotAvailable)
224223
}
224+
225225
pub fn clear(&mut self) {}
226226
pub fn push_build(&mut self, _arr: &arrow_array::ArrayRef) -> Result<()> {
227227
Err(GpuSpatialError::GpuNotAvailable)
@@ -319,9 +319,7 @@ mod tests {
319319
let points = create_array_storage(point_values, &WKB_GEOMETRY);
320320

321321
// 2. Build Refiner
322-
refiner
323-
.init_schema(polygons.data_type(), points.data_type())
324-
.unwrap();
322+
refiner.init_build_schema(polygons.data_type()).unwrap();
325323

326324
refiner.push_build(&polygons).unwrap();
327325

c/sedona-libgpuspatial/src/libgpuspatial.rs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ impl Refiner {
346346
) -> Result<Self, GpuSpatialError> {
347347
let mut refiner = SedonaSpatialRefiner {
348348
clear: None,
349-
init_schema: None,
349+
init_build_schema: None,
350350
push_build: None,
351351
finish_building: None,
352352
refine: None,
@@ -377,26 +377,19 @@ impl Refiner {
377377
/// Initializes the schema for the refiner using the provided build data type.
378378
/// It converts the Arrow `DataType` to the C-compatible `FFI_ArrowSchema` and calls the underlying C function.
379379
/// If initialization fails, it retrieves the error message from the C struct and returns a `GpuSpatialError`.
380-
pub fn init_schema(
381-
&mut self,
382-
build_dt: &DataType,
383-
probe_dt: &DataType,
384-
) -> Result<(), GpuSpatialError> {
380+
pub fn init_build_schema(&mut self, build_dt: &DataType) -> Result<(), GpuSpatialError> {
385381
let build_ffi = FFI_ArrowSchema::try_from(build_dt)?;
386-
let probe_ffi = FFI_ArrowSchema::try_from(probe_dt)?;
387-
let init_fn = self.inner.refiner.init_schema.unwrap();
382+
let init_fn = self
383+
.inner
384+
.refiner
385+
.init_build_schema
386+
.expect("init_build_schema function is None");
388387
let get_last_error = self.inner.refiner.get_last_error;
389388
let refiner_ptr = &mut self.inner.refiner as *mut SedonaSpatialRefiner;
390389

391390
unsafe {
392391
check_ffi_call(
393-
|| {
394-
init_fn(
395-
&mut self.inner.refiner,
396-
&build_ffi as *const _ as *const _,
397-
&probe_ffi as *const _ as *const _,
398-
)
399-
},
392+
|| init_fn(&mut self.inner.refiner, &build_ffi as *const _ as *const _),
400393
get_last_error,
401394
refiner_ptr,
402395
GpuSpatialError::Init,
@@ -464,7 +457,7 @@ impl Refiner {
464457
build_indices: &mut Vec<u32>,
465458
probe_indices: &mut Vec<u32>,
466459
) -> Result<(), GpuSpatialError> {
467-
let (ffi_array, _) = arrow_array::ffi::to_ffi(&array.to_data())?;
460+
let (ffi_array, ffi_schema) = arrow_array::ffi::to_ffi(&array.to_data())?;
468461
let refine_fn = self.inner.refiner.refine.unwrap();
469462
let mut new_len: u32 = 0;
470463

@@ -473,6 +466,7 @@ impl Refiner {
473466
|| {
474467
refine_fn(
475468
&self.inner.refiner as *const _ as *mut _,
469+
&ffi_schema as *const _ as *mut _,
476470
&ffi_array as *const _ as *mut _,
477471
predicate.as_c_uint(),
478472
build_indices.as_mut_ptr(),

0 commit comments

Comments
 (0)