Skip to content

Commit ab113ee

Browse files
committed
Add task_count to context
1 parent ce600dd commit ab113ee

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

src/execution_plans/partition_isolator.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::StageExec;
21
use crate::distributed_physical_optimizer_rule::limit_tasks_err;
32
use crate::execution_plans::DistributedTaskContext;
43
use datafusion::common::{exec_err, plan_err};
@@ -191,12 +190,14 @@ impl ExecutionPlan for PartitionIsolatorExec {
191190
};
192191

193192
let task_context = DistributedTaskContext::from_ctx(&context);
194-
let stage = StageExec::from_ctx(&context)?;
195193

196194
let input_partitions = self_ready.input.output_partitioning().partition_count();
197195

198-
let partition_group =
199-
Self::partition_group(input_partitions, task_context.task_index, stage.tasks.len());
196+
let partition_group = Self::partition_group(
197+
input_partitions,
198+
task_context.task_index,
199+
task_context.task_count,
200+
);
200201

201202
// if our partition group is [7,8,9] and we are asked for parittion 1,
202203
// then look up that index in our group and execute that partition, in this

src/execution_plans/stage.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,10 @@ pub struct ExecutionTask {
149149
#[derive(Debug, Clone, Default)]
150150
pub struct DistributedTaskContext {
151151
pub task_index: usize,
152+
pub task_count: usize,
152153
}
153154

154155
impl DistributedTaskContext {
155-
pub fn new(task_index: usize) -> Self {
156-
Self { task_index }
157-
}
158-
159156
pub fn from_ctx(ctx: &Arc<TaskContext>) -> Arc<Self> {
160157
ctx.session_config()
161158
.get_extension::<Self>()
@@ -335,7 +332,10 @@ impl ExecutionPlan for StageExec {
335332
.session_config()
336333
.clone()
337334
.with_extension(assigned_stage.clone())
338-
.with_extension(Arc::new(DistributedTaskContext { task_index: 0 }));
335+
.with_extension(Arc::new(DistributedTaskContext {
336+
task_index: 0,
337+
task_count: 1,
338+
}));
339339

340340
let new_ctx =
341341
SessionContext::new_with_config_rt(config, context.runtime_env().clone()).task_ctx();

src/flight_service/do_get.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,10 @@ impl ArrowFlightEndpoint {
113113
let cfg = session_state.config_mut();
114114
cfg.set_extension(Arc::clone(&stage));
115115
cfg.set_extension(Arc::new(ContextGrpcMetadata(metadata.into_headers())));
116-
cfg.set_extension(Arc::new(DistributedTaskContext::new(
117-
doget.target_task_index as usize,
118-
)));
116+
cfg.set_extension(Arc::new(DistributedTaskContext {
117+
task_index: doget.target_task_index as usize,
118+
task_count: stage.tasks.len(),
119+
}));
119120

120121
let partition_count = stage.plan.properties().partitioning.partition_count();
121122
let target_partition = doget.target_partition as usize;

0 commit comments

Comments
 (0)