
Commit f8a0dbc

yew1eb and cxzl25 committed

[AURON1739] Support limit with offset

---------

Co-authored-by: cxzl25 <[email protected]>

1 parent 9f15314 commit f8a0dbc

File tree

16 files changed, +310 -89 lines changed


native-engine/auron-serde/proto/auron.proto

Lines changed: 4 additions & 2 deletions
@@ -631,7 +631,8 @@ message SortExecNode {

 message FetchLimit {
   // wrap into a message to make it optional
-  uint64 limit = 1;
+  uint32 limit = 1;
+  uint32 offset = 2;
 }

 message PhysicalRepartition {
@@ -705,7 +706,8 @@ enum AggMode {

 message LimitExecNode {
   PhysicalPlanNode input = 1;
-  uint64 limit = 2;
+  uint32 limit = 2;
+  uint32 offset = 3;
 }

 message FFIReaderExecNode {
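
A note on the new fields: as the test added in limit_exec.rs suggests (limit = 7, offset = 5 yields 2 rows), `limit` counts rows before the offset is applied, following Spark's convention of folding OFFSET into LIMIT. Below is a minimal Rust sketch of how the serialized values relate, assuming a prost-style generated struct for FetchLimit (the derives and the standalone main are illustrative only):

// Sketch only: assumes prost-style codegen for FetchLimit in auron.proto.
#[derive(Clone, Copy, Debug, Default)]
pub struct FetchLimit {
    pub limit: u32,  // total rows to fetch, including the skipped prefix
    pub offset: u32, // rows to drop from the front of the result
}

fn main() {
    // SQL `LIMIT 2 OFFSET 5` arrives as limit = 7, offset = 5.
    let fetch = FetchLimit { limit: 7, offset: 5 };
    assert_eq!(fetch.limit - fetch.offset, 2); // rows actually returned
}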

native-engine/auron-serde/src/from_proto.rs

Lines changed: 21 additions & 7 deletions
@@ -315,12 +315,18 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     panic!("Failed to parse physical sort expressions: {}", e);
                 });

+                let fetch = sort.fetch_limit.as_ref();
+                let limit_for_sort = fetch.map(|f| f.limit as usize);
+                let offset = fetch.map(|f| f.offset as usize).unwrap_or(0);
+                let mut plan: Arc<dyn ExecutionPlan> =
+                    Arc::new(SortExec::new(input, exprs, limit_for_sort));
+
+                if offset > 0 {
+                    plan = Arc::new(LimitExec::new(plan, usize::MAX, offset));
+                }
+
                 // always preserve partitioning
-                Ok(Arc::new(SortExec::new(
-                    input,
-                    exprs,
-                    sort.fetch_limit.as_ref().map(|limit| limit.limit as usize),
-                )))
+                Ok(plan)
             }
             PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => {
                 let input: Arc<dyn ExecutionPlan> = convert_box_required!(bhm.input)?;
@@ -501,7 +507,11 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
             }
             PhysicalPlanType::Limit(limit) => {
                 let input: Arc<dyn ExecutionPlan> = convert_box_required!(limit.input)?;
-                Ok(Arc::new(LimitExec::new(input, limit.limit)))
+                Ok(Arc::new(LimitExec::new(
+                    input,
+                    limit.limit as usize,
+                    limit.offset as usize,
+                )))
             }
             PhysicalPlanType::FfiReader(ffi_reader) => {
                 let schema = Arc::new(convert_required!(ffi_reader.schema)?);
@@ -513,7 +523,11 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
             }
             PhysicalPlanType::CoalesceBatches(coalesce_batches) => {
                 let input: Arc<dyn ExecutionPlan> = convert_box_required!(coalesce_batches.input)?;
-                Ok(Arc::new(LimitExec::new(input, coalesce_batches.batch_size)))
+                Ok(Arc::new(LimitExec::new(
+                    input,
+                    coalesce_batches.batch_size as usize,
+                    0,
+                )))
             }
             PhysicalPlanType::Expand(expand) => {
                 let schema = Arc::new(convert_required!(expand.schema)?);
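
The sort arm above pushes the full fetch limit into SortExec and, only when an offset is present, wraps the sort in a LimitExec with limit = usize::MAX and skip = offset, so the wrapper merely drops the leading rows of the already-truncated sort output. A self-contained sketch of that composition over plain vectors (not Arrow batches; sort_with_fetch_and_offset is a hypothetical helper for illustration):

// SortExec keeps the top `limit` rows; the wrapping LimitExec(usize::MAX, offset)
// then drops the first `offset` of them.
fn sort_with_fetch_and_offset(mut rows: Vec<i32>, limit: usize, offset: usize) -> Vec<i32> {
    rows.sort();                            // SortExec
    rows.truncate(limit);                   // fetch limit pushed into the sort
    rows.into_iter().skip(offset).collect() // LimitExec with skip = offset
}

fn main() {
    // `ORDER BY x LIMIT 2 OFFSET 3` arrives as limit = 5, offset = 3.
    let rows = vec![9, 3, 7, 1, 5, 8, 2];
    assert_eq!(sort_with_fetch_and_offset(rows, 5, 3), vec![5, 7]);
}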

native-engine/datafusion-ext-plans/src/limit_exec.rs

Lines changed: 87 additions & 12 deletions
@@ -41,16 +41,18 @@ use crate::common::execution_context::ExecutionContext;
 #[derive(Debug)]
 pub struct LimitExec {
     input: Arc<dyn ExecutionPlan>,
-    limit: u64,
+    limit: usize,
+    skip: usize,
     pub metrics: ExecutionPlanMetricsSet,
     props: OnceCell<PlanProperties>,
 }

 impl LimitExec {
-    pub fn new(input: Arc<dyn ExecutionPlan>, limit: u64) -> Self {
+    pub fn new(input: Arc<dyn ExecutionPlan>, limit: usize, skip: usize) -> Self {
         Self {
             input,
             limit,
+            skip,
             metrics: ExecutionPlanMetricsSet::new(),
             props: OnceCell::new(),
         }
@@ -59,7 +61,7 @@ impl LimitExec {

 impl DisplayAs for LimitExec {
     fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
-        write!(f, "LimitExec(limit={})", self.limit)
+        write!(f, "LimitExec(limit={},skip={})", self.limit, self.skip)
     }
 }

@@ -95,7 +97,11 @@ impl ExecutionPlan for LimitExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(Self::new(children[0].clone(), self.limit)))
+        Ok(Arc::new(Self::new(
+            children[0].clone(),
+            self.limit,
+            self.skip,
+        )))
     }

     fn execute(
@@ -105,23 +111,27 @@ impl ExecutionPlan for LimitExec {
     ) -> Result<SendableRecordBatchStream> {
         let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics);
         let input = exec_ctx.execute_with_input_stats(&self.input)?;
-        execute_limit(input, self.limit, exec_ctx)
+        if self.skip == 0 {
+            execute_limit(input, self.limit, exec_ctx)
+        } else {
+            execute_limit_with_skip(input, self.limit, self.skip, exec_ctx)
+        }
     }

     fn statistics(&self) -> Result<Statistics> {
         Statistics::with_fetch(
             self.input.statistics()?,
             self.schema(),
-            Some(self.limit as usize),
-            0,
+            Some(self.limit),
+            self.skip,
             1,
         )
     }
 }

 fn execute_limit(
     mut input: SendableRecordBatchStream,
-    limit: u64,
+    limit: usize,
     exec_ctx: Arc<ExecutionContext>,
 ) -> Result<SendableRecordBatchStream> {
     Ok(exec_ctx
@@ -131,11 +141,49 @@ fn execute_limit(
             while remaining > 0
                 && let Some(mut batch) = input.next().await.transpose()?
             {
-                if remaining < batch.num_rows() as u64 {
-                    batch = batch.slice(0, remaining as usize);
+                if remaining < batch.num_rows() {
+                    batch = batch.slice(0, remaining);
+                    remaining = 0;
+                } else {
+                    remaining -= batch.num_rows();
+                }
+                exec_ctx.baseline_metrics().record_output(batch.num_rows());
+                sender.send(batch).await;
+            }
+            Ok(())
+        }))
+}
+
+fn execute_limit_with_skip(
+    mut input: SendableRecordBatchStream,
+    limit: usize,
+    offset: usize,
+    exec_ctx: Arc<ExecutionContext>,
+) -> Result<SendableRecordBatchStream> {
+    Ok(exec_ctx
+        .clone()
+        .output_with_sender("Limit", move |sender| async move {
+            let mut skip = offset;
+            let mut remaining = limit - skip;
+            while remaining > 0
+                && let Some(mut batch) = input.next().await.transpose()?
+            {
+                if skip > 0 {
+                    let rows = batch.num_rows();
+                    if skip >= rows {
+                        skip -= rows;
+                        continue;
+                    }
+
+                    batch = batch.slice(skip, rows - skip);
+                    skip = 0;
+                }
+
+                if remaining < batch.num_rows() {
+                    batch = batch.slice(0, remaining);
                     remaining = 0;
                 } else {
-                    remaining -= batch.num_rows() as u64;
+                    remaining -= batch.num_rows();
                 }
                 exec_ctx.baseline_metrics().record_output(batch.num_rows());
                 sender.send(batch).await;
@@ -203,7 +251,7 @@ mod test {
             ("b", &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
             ("c", &vec![5, 6, 7, 8, 9, 0, 1, 2, 3, 4]),
         );
-        let limit_exec = LimitExec::new(input, 2_u64);
+        let limit_exec = LimitExec::new(input, 2, 0);
         let session_ctx = SessionContext::new();
         let task_ctx = session_ctx.task_ctx();
         let output = limit_exec.execute(0, task_ctx).unwrap();
@@ -222,4 +270,31 @@ mod test {
         assert_eq!(row_count, Precision::Exact(2));
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_limit_with_skip() -> Result<()> {
+        let input = build_table(
+            ("a", &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+            ("b", &vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
+            ("c", &vec![5, 6, 7, 8, 9, 0, 1, 2, 3, 4]),
+        );
+        let limit_exec = LimitExec::new(input, 7, 5);
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let output = limit_exec.execute(0, task_ctx).unwrap();
+        let batches = common::collect(output).await?;
+        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
+
+        let expected = vec![
+            "+---+---+---+",
+            "| a | b | c |",
+            "+---+---+---+",
+            "| 5 | 4 | 0 |",
+            "| 6 | 3 | 1 |",
+            "+---+---+---+",
+        ];
+        assert_batches_eq!(expected, &batches);
+        assert_eq!(row_count, 2);
+        Ok(())
+    }
 }
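
The batch arithmetic in execute_limit_with_skip can be exercised in isolation. The sketch below mirrors it with Vec<i32> standing in for RecordBatch; limit_with_skip is a hypothetical stand-in, and it uses saturating_sub where the committed code subtracts directly, so a skip larger than the limit yields zero rows rather than underflowing:

// Standalone sketch of the skip/limit slicing in execute_limit_with_skip.
// `limit` counts rows before the offset is applied (Spark convention),
// so the output holds at most limit - offset rows.
fn limit_with_skip(batches: Vec<Vec<i32>>, limit: usize, offset: usize) -> Vec<Vec<i32>> {
    let mut skip = offset;
    let mut remaining = limit.saturating_sub(skip); // the diff uses plain `limit - skip`
    let mut out = Vec::new();
    for mut batch in batches {
        if remaining == 0 {
            break;
        }
        if skip > 0 {
            let rows = batch.len();
            if skip >= rows {
                skip -= rows; // whole batch falls inside the offset
                continue;
            }
            batch = batch[skip..].to_vec(); // batch.slice(skip, rows - skip)
            skip = 0;
        }
        if remaining < batch.len() {
            batch.truncate(remaining); // batch.slice(0, remaining)
            remaining = 0;
        } else {
            remaining -= batch.len();
        }
        out.push(batch);
    }
    out
}

fn main() {
    // Mirrors test_limit_with_skip: limit = 7, offset = 5 over 10 rows -> rows 5 and 6.
    let batches: Vec<Vec<i32>> = vec![(0..4).collect(), (4..8).collect(), (8..10).collect()];
    assert_eq!(limit_with_skip(batches, 7, 5), vec![vec![5, 6]]);
}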

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala

Lines changed: 31 additions & 15 deletions
@@ -50,14 +50,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.CoalescedPartitionSpec
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.PartialMapperPartitionSpec
-import org.apache.spark.sql.execution.PartialReducerPartitionSpec
-import org.apache.spark.sql.execution.ShuffledRowRDD
-import org.apache.spark.sql.execution.ShufflePartitionSpec
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.auron.plan._
 import org.apache.spark.sql.execution.auron.plan.ConvertToNativeExec
@@ -288,16 +281,38 @@ class ShimsImpl extends Shims with Logging {
       child: SparkPlan): NativeGenerateBase =
     NativeGenerateExec(generator, requiredChildOutput, outer, generatorOutput, child)

-  override def createNativeGlobalLimitExec(limit: Long, child: SparkPlan): NativeGlobalLimitBase =
-    NativeGlobalLimitExec(limit, child)
+  private def effectiveLimit(rawLimit: Int): Int =
+    if (rawLimit == -1) Int.MaxValue else rawLimit

-  override def createNativeLocalLimitExec(limit: Long, child: SparkPlan): NativeLocalLimitBase =
+  @sparkver("3.4 / 3.5")
+  override def getLimitAndOffset(plan: GlobalLimitExec): (Int, Int) = {
+    (effectiveLimit(plan.limit), plan.offset)
+  }
+
+  @sparkver("3.4 / 3.5")
+  override def getLimitAndOffset(plan: TakeOrderedAndProjectExec): (Int, Int) = {
+    (effectiveLimit(plan.limit), plan.offset)
+  }
+
+  override def createNativeGlobalLimitExec(
+      limit: Int,
+      offset: Int,
+      child: SparkPlan): NativeGlobalLimitBase =
+    NativeGlobalLimitExec(limit, offset, child)
+
+  override def createNativeLocalLimitExec(limit: Int, child: SparkPlan): NativeLocalLimitBase =
     NativeLocalLimitExec(limit, child)

+  @sparkver("3.4 / 3.5")
+  override def getLimitAndOffset(plan: CollectLimitExec): (Int, Int) = {
+    (effectiveLimit(plan.limit), plan.offset)
+  }
+
   override def createNativeCollectLimitExec(
       limit: Int,
+      offset: Int,
       child: SparkPlan): NativeCollectLimitBase =
-    NativeCollectLimitExec(limit, child)
+    NativeCollectLimitExec(limit, offset, child)

   override def createNativeParquetInsertIntoHiveTableExec(
       cmd: InsertIntoHiveTable,
@@ -334,13 +349,14 @@ class ShimsImpl extends Shims with Logging {
     NativeSortExec(sortOrder, global, child)

   override def createNativeTakeOrderedExec(
-      limit: Long,
+      limit: Int,
+      offset: Int,
       sortOrder: Seq[SortOrder],
       child: SparkPlan): NativeTakeOrderedBase =
-    NativeTakeOrderedExec(limit, sortOrder, child)
+    NativeTakeOrderedExec(limit, offset, sortOrder, child)

   override def createNativePartialTakeOrderedExec(
-      limit: Long,
+      limit: Int,
       sortOrder: Seq[SortOrder],
       child: SparkPlan,
       metrics: Map[String, SQLMetric]): NativePartialTakeOrderedBase =
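
Note the private effectiveLimit helper above: Spark encodes "no limit" as -1, which the shim normalizes to Int.MaxValue before the value is handed to the native side. The same mapping, sketched in Rust for consistency with the other examples here (effective_limit is an illustrative stand-in for the Scala helper):

// Rust stand-in for ShimsImpl.effectiveLimit: map Spark's -1 sentinel
// ("no limit") to the maximum representable limit so downstream
// arithmetic stays valid.
fn effective_limit(raw_limit: i32) -> i32 {
    if raw_limit == -1 { i32::MAX } else { raw_limit }
}

fn main() {
    assert_eq!(effective_limit(-1), i32::MAX);
    assert_eq!(effective_limit(100), 100);
}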

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeCollectLimitExec.scala

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@ import org.apache.spark.sql.execution.SparkPlan

 import org.apache.auron.sparkver

-case class NativeCollectLimitExec(limit: Int, override val child: SparkPlan)
-    extends NativeCollectLimitBase(limit, child) {
+case class NativeCollectLimitExec(limit: Int, offset: Int, override val child: SparkPlan)
+    extends NativeCollectLimitBase(limit, offset, child) {

   @sparkver("3.2 / 3.3 / 3.4 / 3.5")
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeGlobalLimitExec.scala

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@ import org.apache.spark.sql.execution.SparkPlan

 import org.apache.auron.sparkver

-case class NativeGlobalLimitExec(limit: Long, override val child: SparkPlan)
-    extends NativeGlobalLimitBase(limit, child) {
+case class NativeGlobalLimitExec(limit: Int, offset: Int, override val child: SparkPlan)
+    extends NativeGlobalLimitBase(limit, offset, child) {

   @sparkver("3.2 / 3.3 / 3.4 / 3.5")
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeLocalLimitExec.scala

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.SparkPlan

 import org.apache.auron.sparkver

-case class NativeLocalLimitExec(limit: Long, override val child: SparkPlan)
+case class NativeLocalLimitExec(limit: Int, override val child: SparkPlan)
     extends NativeLocalLimitBase(limit, child) {

   @sparkver("3.2 / 3.3 / 3.4 / 3.5")

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativePartialTakeOrderedExec.scala

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.auron.sparkver

 case class NativePartialTakeOrderedExec(
-    limit: Long,
+    limit: Int,
     sortOrder: Seq[SortOrder],
     override val child: SparkPlan,
     override val metrics: Map[String, SQLMetric])

spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeTakeOrderedExec.scala

Lines changed: 3 additions & 2 deletions
@@ -22,10 +22,11 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.auron.sparkver

 case class NativeTakeOrderedExec(
-    limit: Long,
+    limit: Int,
+    offset: Int,
     sortOrder: Seq[SortOrder],
     override val child: SparkPlan)
-    extends NativeTakeOrderedBase(limit, sortOrder, child) {
+    extends NativeTakeOrderedBase(limit, offset, sortOrder, child) {

   @sparkver("3.2 / 3.3 / 3.4 / 3.5")
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
